Skip to content

Commit 9b4ae33

Browse files
committed
improve speed on eval
1 parent eb7c473 commit 9b4ae33

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -432,15 +432,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
432 432
"""
433 433
Computes the evaluation loss for a single batch without updating model weights.
434 434
"""
435-
_, new_rng = jax.random.split(rng, num=2)
436 435

437 436
# The loss function logic is identical to training. We are evaluating the model's
438 437
# ability to perform its core training objective (e.g., denoising).
439-
def loss_fn(params, latents, encoder_hidden_states, timesteps):
438+
@jax.jit
439+
def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
440 440
# Reconstruct the model from its definition and parameters
441 441
model = nnx.merge(state.graphdef, params, state.rest_of_state)
442 442

443-
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
443+
noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
444 444
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
445 445

446 446
# Get the model's prediction
@@ -472,7 +472,8 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps):
472 472
latents= data["latents"][start:end, :].astype(config.weights_dtype)
473 473
encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
474 474
timesteps = data["timesteps"][start:end].astype("int64")
475-
loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps)
475+
_, new_rng = jax.random.split(rng, num=2)
476+
loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
476 477
losses = losses.at[start:end].set(loss)
477 478

478 479
# Structure the metrics for logging and aggregation

0 commit comments

Comments
 (0)