@@ -100,27 +100,26 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
100100
101101def write_metrics_to_tensorboard (writer , metrics , step , config ):
102102 """Writes metrics to tensorboard"""
103- with jax .spmd_mode ("allow_all" ):
104- if jax .process_index () == 0 :
105- for metric_name in metrics .get ("scalar" , []):
106- writer .add_scalar (metric_name , np .array (metrics ["scalar" ][metric_name ]), step )
107- for metric_name in metrics .get ("scalars" , []):
108- writer .add_scalars (metric_name , metrics ["scalars" ][metric_name ], step )
109-
110- full_log = step % config .log_period == 0
111- if jax .process_index () == 0 :
112- max_logging .log (
113- "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}" .format (
114- step ,
115- metrics ["scalar" ]["perf/step_time_seconds" ],
116- metrics ["scalar" ]["perf/per_device_tflops_per_sec" ],
117- float (metrics ["scalar" ]["learning/loss" ]),
118- )
119- )
103+ if jax .process_index () == 0 :
104+ for metric_name in metrics .get ("scalar" , []):
105+ writer .add_scalar (metric_name , np .array (metrics ["scalar" ][metric_name ]), step )
106+ for metric_name in metrics .get ("scalars" , []):
107+ writer .add_scalars (metric_name , metrics ["scalars" ][metric_name ], step )
108+
109+ full_log = step % config .log_period == 0
110+ if jax .process_index () == 0 :
111+ max_logging .log (
112+ "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}" .format (
113+ step ,
114+ metrics ["scalar" ]["perf/step_time_seconds" ],
115+ metrics ["scalar" ]["perf/per_device_tflops_per_sec" ],
116+ float (metrics ["scalar" ]["learning/loss" ]),
117+ )
118+ )
120119
121- if full_log and jax .process_index () == 0 :
122- max_logging .log (f"To see full metrics 'tensorboard --logdir={ config .tensorboard_dir } '" )
123- writer .flush ()
120+ if full_log and jax .process_index () == 0 :
121+ max_logging .log (f"To see full metrics 'tensorboard --logdir={ config .tensorboard_dir } '" )
122+ writer .flush ()
124123
125124
126125def get_params_to_save (params ):
0 commit comments