@@ -87,13 +87,36 @@ def l2norm_pytree(x):
8787
8888def activate_profiler (config ):
8989 if jax .process_index () == 0 and config .enable_profiler :
90- jax .profiler .start_trace (config .tensorboard_dir )
90+ trace_dir = config .tensorboard_dir
91+ if trace_dir .startswith ("gs://" ):
92+ trace_dir = "/tmp/profiler_traces"
93+ os .makedirs (trace_dir , exist_ok = True )
94+ max_logging .log (f"Starting profiler trace in: { trace_dir } " )
95+ jax .profiler .start_trace (trace_dir )
9196
9297
9398def deactivate_profiler (config ):
9499 if jax .process_index () == 0 and config .enable_profiler :
95100 jax .profiler .stop_trace ()
96101
102+ trace_dir = config .tensorboard_dir
103+ if trace_dir .startswith ("gs://" ):
104+ local_dir = "/tmp/profiler_traces"
105+ if os .path .exists (local_dir ):
106+ max_logging .log (f"Uploading profiler traces from { local_dir } to { trace_dir } ..." )
107+ client = storage .Client ()
108+ bucket_name , prefix = parse_gcs_bucket_and_prefix (trace_dir )
109+ bucket = client .bucket (bucket_name )
110+
111+ for root , _ , files in os .walk (local_dir ):
112+ for file in files :
113+ local_file = os .path .join (root , file )
114+ rel_path = os .path .relpath (local_file , local_dir )
115+ blob_name = os .path .join (prefix , rel_path )
116+ blob = bucket .blob (blob_name )
117+ blob .upload_from_filename (local_file )
118+ max_logging .log (f"Uploaded { local_file } to gs://{ bucket_name } /{ blob_name } " )
119+
97120
98121def initialize_summary_writer (config ):
99122 return writer .SummaryWriter (config .tensorboard_dir ) if jax .process_index () == 0 else None
0 commit comments