@@ -285,7 +285,7 @@ def model_fn(features, labels, mode, params):
285285
286286 # Computations to be executed on CPU, outside of the main TPU queues.
287287 def eval_metrics_host_call_fn (
288- features ,
288+ avg_stones , avg_stones_delta ,
289289 policy_output , value_output ,
290290 pi_tensor , value_tensor ,
291291 policy_cost , value_cost ,
@@ -309,8 +309,6 @@ def eval_metrics_host_call_fn(
309309
310310 value_cost_normalized = value_cost / params ['value_cost_weight' ]
311311 avg_value_observed = tf .reduce_mean (value_tensor )
312- avg_stones_black = tf .reduce_mean (tf .reduce_sum (features [:,:,:,1 ], [1 ,2 ]))
313- avg_stones_white = tf .reduce_mean (tf .reduce_sum (features [:,:,:,0 ], [1 ,2 ]))
314312
315313 with tf .variable_scope ('metrics' ):
316314 metric_ops = {
@@ -329,8 +327,7 @@ def eval_metrics_host_call_fn(
329327 'policy_target_top_1_confidence' : tf .metrics .mean (
330328 policy_target_top_1_confidence ),
331329 'avg_value_observed' : tf .metrics .mean (avg_value_observed ),
332- 'avg_stones_black' : tf .metrics .mean (avg_stones_black ),
333- 'avg_stones_white' : tf .metrics .mean (avg_stones_white ),
330+ 'avg_stones_black' : tf .metrics .mean (tf .reduce_mean (avg_stones )),
334331 }
335332
336333 if est_mode == tf .estimator .ModeKeys .EVAL :
@@ -348,6 +345,8 @@ def eval_metrics_host_call_fn(
348345 for metric_name , metric_op in metric_ops .items ():
349346 summary .scalar (metric_name , metric_op [1 ], step = eval_step )
350347
348+ summary .histogram ("avg_stones_white" , avg_stones_delta )
349+
351350 # Reset metrics occasionally so that they are mean of recent batches.
352351 reset_op = tf .variables_initializer (tf .local_variables ('metrics' ))
353352 cond_reset_op = tf .cond (
@@ -357,8 +356,14 @@ def eval_metrics_host_call_fn(
357356
358357 return summary .all_summary_ops () + [cond_reset_op ]
359358
359+ # compute here to avoid sending all of features to cpu.
360+ avg_stones_black = tf .reduce_sum (features [:,:,:,1 ], [1 ,2 ])
361+ avg_stones_white = tf .reduce_sum (features [:,:,:,0 ], [1 ,2 ])
362+ avg_stones = avg_stones_black + avg_stones_white
363+ avg_stones_delta = avg_stones_black - avg_stones_white
364+
360365 metric_args = [
361- features ,
366+ avg_stones , avg_stones_delta ,
362367 policy_output ,
363368 value_output ,
364369 labels ['pi_tensor' ],
0 commit comments