Skip to content

Commit 3b35dd5

Browse files
committed
fix OOM problem
1 parent 5503f9c commit 3b35dd5

1 file changed

Lines changed: 24 additions & 9 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
342342
# Check if any evaluation was actually performed
343343
if eval_losses_by_timestep:
344344
mean_per_timestep = []
345-
max_logging.log(f"Step {step}, calculating mean loss per timestep...")
345+
if jax.process_index() == 0:
346+
max_logging.log(f"Step {step}, calculating mean loss per timestep...")
346347
for timestep, losses in sorted(eval_losses_by_timestep.items()):
347348
losses = jnp.array(losses)
348349
losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
349350
mean_loss = jnp.mean(losses)
350-
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
351+
if jax.process_index() == 0:
352+
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
351353
mean_per_timestep.append(mean_loss)
352354
final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
353-
max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
355+
if jax.process_index() == 0:
356+
max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
354357
if writer:
355358
writer.add_scalar("learning/eval_loss", final_eval_loss, step)
356359
else:
@@ -428,14 +431,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
428431

429432
# The loss function logic is identical to training. We are evaluating the model's
430433
# ability to perform its core training objective (e.g., denoising).
431-
def loss_fn(params):
434+
def loss_fn(params, latents, encoder_hidden_states, timesteps):
432435
# Reconstruct the model from its definition and parameters
433436
model = nnx.merge(state.graphdef, params, state.rest_of_state)
434437

435438
# Prepare inputs
436-
latents = data["latents"].astype(config.weights_dtype)
437-
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
438-
timesteps = data["timesteps"].astype("int64")
439+
# latents = data["latents"].astype(config.weights_dtype)
440+
# encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
441+
# timesteps = data["timesteps"].astype("int64")
439442

440443
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
441444
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
@@ -460,10 +463,22 @@ def loss_fn(params):
460463
# --- Key Difference from train_step ---
461464
# Directly compute the loss without calculating gradients.
462465
# The model's state.params are used but not updated.
463-
loss = loss_fn(state.params)
466+
bs = len(data["latents"])
467+
single_batch_size = min(8, config.global_batch_size_to_train_on)
468+
losses = jnp.zeros(bs)
469+
for i in range(0, bs, single_batch_size):
470+
start = i
471+
end = min(i + single_batch_size, bs)
472+
jax.debug.print("Eval step processing samples {start} to {end}", start=start, end=end)
473+
latents= data["latents"][start:end, :].astype(config.weights_dtype)
474+
encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
475+
timesteps = data["timesteps"][start:end].astype("int64")
476+
loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps)
477+
losses = losses.at[start:end].set(loss)
464478

465479
# Structure the metrics for logging and aggregation
466-
metrics = {"scalar": {"learning/eval_loss": loss}}
480+
metrics = {"scalar": {"learning/eval_loss": losses}}
481+
jax.debug.print("Eval step losses: {losses}", losses=losses)
467482

468483
# Return the computed metrics and the new RNG key for the next eval step
469484
return metrics, new_rng

0 commit comments

Comments (0)