@@ -432,15 +432,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
432432 """
433433 Computes the evaluation loss for a single batch without updating model weights.
434434 """
435- _ , new_rng = jax .random .split (rng , num = 2 )
436435
437436 # The loss function logic is identical to training. We are evaluating the model's
438437 # ability to perform its core training objective (e.g., denoising).
439- def loss_fn (params , latents , encoder_hidden_states , timesteps ):
438+ @jax .jit
439+ def loss_fn (params , latents , encoder_hidden_states , timesteps , rng ):
440440 # Reconstruct the model from its definition and parameters
441441 model = nnx .merge (state .graphdef , params , state .rest_of_state )
442442
443- noise = jax .random .normal (key = new_rng , shape = latents .shape , dtype = latents .dtype )
443+ noise = jax .random .normal (key = rng , shape = latents .shape , dtype = latents .dtype )
444444 noisy_latents = scheduler .add_noise (scheduler_state , latents , noise , timesteps )
445445
446446 # Get the model's prediction
@@ -472,7 +472,8 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps):
472472 latents = data ["latents" ][start :end , :].astype (config .weights_dtype )
473473 encoder_hidden_states = data ["encoder_hidden_states" ][start :end , :].astype (config .weights_dtype )
474474 timesteps = data ["timesteps" ][start :end ].astype ("int64" )
475- loss = loss_fn (state .params , latents , encoder_hidden_states , timesteps )
475+ _ , new_rng = jax .random .split (rng , num = 2 )
476+ loss = loss_fn (state .params , latents , encoder_hidden_states , timesteps , new_rng )
476477 losses = losses .at [start :end ].set (loss )
477478
478479 # Structure the metrics for logging and aggregation
0 commit comments