Skip to content

Commit e6f495e

Browse files
committed
solve comment
1 parent e96302c commit e6f495e

2 files changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,6 @@ quantization_calibration_method: "absmax"
318318
eval_every: -1
319319
eval_data_dir: ""
320320
enable_generate_video_for_eval: False # This will increase the used TPU memory.
321-
eval_max_number_of_samples_in_bucket: 60
321+
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).
322322

323323
enable_ssim: True

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
472472
# --- Key Difference from train_step ---
473473
# Directly compute the loss without calculating gradients.
474474
# The model's state.params are used but not updated.
475+
# TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop
475476
bs = len(data["latents"])
476477
single_batch_size = config.global_batch_size_to_train_on
477478
losses = jnp.zeros(bs)

0 commit comments

Comments
 (0)