diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index e5a34a763..a040f85d2 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -100,27 +100,26 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step def write_metrics_to_tensorboard(writer, metrics, step, config): """Writes metrics to tensorboard""" - with jax.spmd_mode("allow_all"): - if jax.process_index() == 0: - for metric_name in metrics.get("scalar", []): - writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) - for metric_name in metrics.get("scalars", []): - writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) - - full_log = step % config.log_period == 0 - if jax.process_index() == 0: - max_logging.log( - "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( - step, - metrics["scalar"]["perf/step_time_seconds"], - metrics["scalar"]["perf/per_device_tflops_per_sec"], - float(metrics["scalar"]["learning/loss"]), - ) - ) + if jax.process_index() == 0: + for metric_name in metrics.get("scalar", []): + writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) + for metric_name in metrics.get("scalars", []): + writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) + + full_log = step % config.log_period == 0 + if jax.process_index() == 0: + max_logging.log( + "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( + step, + metrics["scalar"]["perf/step_time_seconds"], + metrics["scalar"]["perf/per_device_tflops_per_sec"], + float(metrics["scalar"]["learning/loss"]), + ) + ) - if full_log and jax.process_index() == 0: - max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") - writer.flush() + if full_log and jax.process_index() == 0: + max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") + writer.flush() def get_params_to_save(params):