Skip to content

Commit 0af353d

Browse files
authored
Remove jax.spmd_mode, it's no longer supported (#179)
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent 34454fb commit 0af353d

1 file changed

Lines changed: 19 additions & 20 deletions

File tree

src/maxdiffusion/train_utils.py

Lines changed: 19 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -100,27 +100,26 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
100100

101101
def write_metrics_to_tensorboard(writer, metrics, step, config):
102102
"""Writes metrics to tensorboard"""
103-
with jax.spmd_mode("allow_all"):
104-
if jax.process_index() == 0:
105-
for metric_name in metrics.get("scalar", []):
106-
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
107-
for metric_name in metrics.get("scalars", []):
108-
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
109-
110-
full_log = step % config.log_period == 0
111-
if jax.process_index() == 0:
112-
max_logging.log(
113-
"completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
114-
step,
115-
metrics["scalar"]["perf/step_time_seconds"],
116-
metrics["scalar"]["perf/per_device_tflops_per_sec"],
117-
float(metrics["scalar"]["learning/loss"]),
118-
)
119-
)
103+
if jax.process_index() == 0:
104+
for metric_name in metrics.get("scalar", []):
105+
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
106+
for metric_name in metrics.get("scalars", []):
107+
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
108+
109+
full_log = step % config.log_period == 0
110+
if jax.process_index() == 0:
111+
max_logging.log(
112+
"completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
113+
step,
114+
metrics["scalar"]["perf/step_time_seconds"],
115+
metrics["scalar"]["perf/per_device_tflops_per_sec"],
116+
float(metrics["scalar"]["learning/loss"]),
117+
)
118+
)
120119

121-
if full_log and jax.process_index() == 0:
122-
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
123-
writer.flush()
120+
if full_log and jax.process_index() == 0:
121+
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
122+
writer.flush()
124123

125124

126125
def get_params_to_save(params):

0 commit comments

Comments
 (0)