Skip to content

Commit 6004e92

Browse files
committed
update gcs bucket upload logic to ensure there's a local copy if it fails
1 parent 673f533 commit 6004e92

1 file changed

Lines changed: 24 additions & 1 deletion

File tree

src/maxdiffusion/max_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,36 @@ def l2norm_pytree(x):
8787

8888
def activate_profiler(config):
8989
if jax.process_index() == 0 and config.enable_profiler:
90-
jax.profiler.start_trace(config.tensorboard_dir)
90+
trace_dir = config.tensorboard_dir
91+
if trace_dir.startswith("gs://"):
92+
trace_dir = "/tmp/profiler_traces"
93+
os.makedirs(trace_dir, exist_ok=True)
94+
max_logging.log(f"Starting profiler trace in: {trace_dir}")
95+
jax.profiler.start_trace(trace_dir)
9196

9297

9398
def deactivate_profiler(config):
9499
if jax.process_index() == 0 and config.enable_profiler:
95100
jax.profiler.stop_trace()
96101

102+
trace_dir = config.tensorboard_dir
103+
if trace_dir.startswith("gs://"):
104+
local_dir = "/tmp/profiler_traces"
105+
if os.path.exists(local_dir):
106+
max_logging.log(f"Uploading profiler traces from {local_dir} to {trace_dir}...")
107+
client = storage.Client()
108+
bucket_name, prefix = parse_gcs_bucket_and_prefix(trace_dir)
109+
bucket = client.bucket(bucket_name)
110+
111+
for root, _, files in os.walk(local_dir):
112+
for file in files:
113+
local_file = os.path.join(root, file)
114+
rel_path = os.path.relpath(local_file, local_dir)
115+
blob_name = os.path.join(prefix, rel_path)
116+
blob = bucket.blob(blob_name)
117+
blob.upload_from_filename(local_file)
118+
max_logging.log(f"Uploaded {local_file} to gs://{bucket_name}/{blob_name}")
119+
97120

98121
def initialize_summary_writer(config):
99122
return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None

0 commit comments

Comments
 (0)