Skip to content

Commit 7293017

Browse files
committed
updated vae config logic to be consistent, updated xprof logic
1 parent dff5c30 commit 7293017

8 files changed

Lines changed: 31 additions & 2 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,19 @@ def block_if_jax(x):
313313
# Block until warmup completes
314314
jax.tree_util.tree_map(block_if_jax, warmup_videos)
315315

316+
# Warm up GCS connection by flushing writer before starting profiler
317+
if writer and jax.process_index() == 0:
318+
max_logging.log("Flushing writer to warm up GCS connection before profiler...")
319+
writer.flush()
320+
316321
s0 = time.perf_counter()
317322
max_utils.activate_profiler(config)
318323
max_logging.log(f"Profiler: starting profiled run with {steps_for_profile} steps")
319324
profiled_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile)
320325
# Wait for all computation to finish before stopping profiler
321326
jax.tree_util.tree_map(block_if_jax, profiled_videos)
322327
max_utils.deactivate_profiler(config)
328+
max_utils.upload_profiler_traces(config)
323329
generation_time_with_profiler = time.perf_counter() - s0
324330
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
325331
if writer and jax.process_index() == 0:

src/maxdiffusion/max_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,30 @@ def l2norm_pytree(x):
7878

7979
def activate_profiler(config):
8080
if jax.process_index() == 0 and config.enable_profiler:
81-
jax.profiler.start_trace(config.tensorboard_dir)
81+
# If tensorboard_dir is GCS, write profiler traces locally instead
82+
profiler_path = config.tensorboard_dir
83+
if config.tensorboard_dir.startswith("gs://"):
84+
profiler_path = "/tmp/profiler_traces"
85+
os.makedirs(profiler_path, exist_ok=True)
86+
max_logging.log(f"Profiler: saving traces locally to {profiler_path} (GCS paths not supported)")
87+
jax.profiler.start_trace(profiler_path)
8288

8389

8490
def deactivate_profiler(config):
8591
if jax.process_index() == 0 and config.enable_profiler:
8692
jax.profiler.stop_trace()
8793

8894

95+
def upload_profiler_traces(config):
96+
"""No-op for now - profiler traces are saved locally"""
97+
if jax.process_index() == 0 and config.enable_profiler:
98+
if config.tensorboard_dir.startswith("gs://"):
99+
max_logging.log("Profiler traces saved to: /tmp/profiler_traces")
100+
max_logging.log("You can download them manually or use: gsutil -m rsync -r /tmp/profiler_traces/ " + config.tensorboard_dir.rstrip("/") + "/")
101+
else:
102+
max_logging.log(f"Profiler traces saved to: {config.tensorboard_dir}")
103+
104+
89105
def initialize_summary_writer(config):
90106
return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None
91107

@@ -94,7 +110,6 @@ def close_summary_writer(summary_writer):
94110
if jax.process_index() == 0:
95111
summary_writer.close()
96112

97-
98113
def _prepare_metrics_for_json(metrics, step, run_name):
99114
"""Converts metric dictionary into json supported types (e.g. float)"""
100115
metrics_dict = {}

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def _reshape_batch_dim_to_heads(tensor, heads):
112112
head_size = heads
113113
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
114114
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
115+
reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
115116
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
116117
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)
117118

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
5757
devices_array=common_components["devices_array"],
5858
mesh=common_components["mesh"],
5959
vae_mesh=common_components["vae_mesh"],
60+
vae_logical_axis_rules=common_components["vae_logical_axis_rules"],
6061
config=config,
6162
)
6263

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
7474
devices_array=common_components["devices_array"],
7575
mesh=common_components["mesh"],
7676
vae_mesh=common_components["vae_mesh"],
77+
vae_logical_axis_rules=common_components["vae_logical_axis_rules"],
7778
config=config,
7879
)
7980
return pipeline, low_noise_transformer, high_noise_transformer

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6161
scheduler_state=common_components["scheduler_state"],
6262
devices_array=common_components["devices_array"],
6363
mesh=common_components["mesh"],
64+
vae_mesh=common_components["vae_mesh"],
65+
vae_logical_axis_rules=common_components["vae_logical_axis_rules"],
6466
config=config,
6567
)
6668
return pipeline, transformer

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
7979
scheduler_state=common_components["scheduler_state"],
8080
devices_array=common_components["devices_array"],
8181
mesh=common_components["mesh"],
82+
vae_mesh=common_components["vae_mesh"],
83+
vae_logical_axis_rules=common_components["vae_logical_axis_rules"],
8284
config=config,
8385
)
8486
return pipeline, low_noise_transformer, high_noise_transformer

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
386386

387387
if self.config.enable_profiler and step == last_profiling_step:
388388
max_utils.deactivate_profiler(self.config)
389+
max_utils.upload_profiler_traces(self.config)
389390

390391
train_states[FLUX_STATE_KEY] = flux_state
391392
if len(times) > 0:

0 commit comments

Comments (0)