@@ -78,14 +78,30 @@ def l2norm_pytree(x):
7878
def activate_profiler(config):
    """Start a JAX profiler trace on process 0 when profiling is enabled.

    Traces are written to ``config.tensorboard_dir`` unless that value
    starts with ``gs://``; in that case a local directory is created and
    used instead, and the redirect is logged.
    """
    if jax.process_index() != 0 or not config.enable_profiler:
        return

    trace_dir = config.tensorboard_dir
    if trace_dir.startswith("gs://"):
        # GCS destination: fall back to local disk (the log line records why).
        trace_dir = "/tmp/profiler_traces"
        os.makedirs(trace_dir, exist_ok=True)
        max_logging.log(f"Profiler: saving traces locally to {trace_dir} (GCS paths not supported)")
    jax.profiler.start_trace(trace_dir)
8288
8389
def deactivate_profiler(config):
    """Stop the JAX profiler trace on process 0 when profiling is enabled."""
    profiling_here = jax.process_index() == 0 and config.enable_profiler
    if not profiling_here:
        return
    jax.profiler.stop_trace()
8793
8894
def upload_profiler_traces(config):
    """Log where profiler traces were written; nothing is uploaded here.

    Acts only on process 0 with profiling enabled. When
    ``config.tensorboard_dir`` starts with ``gs://`` it logs the local
    trace directory plus a ``gsutil rsync`` command for copying the traces
    to that location manually; otherwise it logs ``config.tensorboard_dir``.
    """
    if jax.process_index() != 0 or not config.enable_profiler:
        return

    if not config.tensorboard_dir.startswith("gs://"):
        max_logging.log(f"Profiler traces saved to: {config.tensorboard_dir}")
        return

    max_logging.log("Profiler traces saved to: /tmp/profiler_traces")
    # Normalise the destination so it ends with exactly one trailing slash.
    gcs_target = config.tensorboard_dir.rstrip("/") + "/"
    max_logging.log(f"You can download them manually or use: gsutil -m rsync -r /tmp/profiler_traces/ {gcs_target}")
103+
104+
def initialize_summary_writer(config):
    """Return a SummaryWriter for ``config.tensorboard_dir`` on process 0, else None."""
    if jax.process_index() != 0:
        return None
    return writer.SummaryWriter(config.tensorboard_dir)
91107
@@ -94,7 +110,6 @@ def close_summary_writer(summary_writer):
94110 if jax .process_index () == 0 :
95111 summary_writer .close ()
96112
97-
98113def _prepare_metrics_for_json (metrics , step , run_name ):
99114 """Converts metric dictionary into json supported types (e.g. float)"""
100115 metrics_dict = {}
0 commit comments