Skip to content

Commit eb7c473

Browse files
committed
remove print log
1 parent aaaa094 commit eb7c473

2 files changed

Lines changed: 5 additions & 11 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,4 @@ eval_every: -1
319319
eval_data_dir: ""
320320
enable_generate_video_for_eval: False # This will increase the used TPU memory.
321321
eval_max_number_of_samples_in_bucket: 60
322+
eval_max_processed_batch_size: 8 # This is the max batch size per device for eval step. If the global eval batch size is larger than this, the eval step will be run multiple times.

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def start_training(self):
212212

213213
pipeline = self.load_checkpoint()
214214
# Generate a sample before training to compare against generated sample after training.
215-
# pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
215+
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
216216

217217
if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
218218
# save some memory.
@@ -230,8 +230,8 @@ def start_training(self):
230230
# Returns pipeline with trained transformer state
231231
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
232232

233-
# posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
234-
# print_ssim(pretrained_video_path, posttrained_video_path)
233+
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
234+
print_ssim(pretrained_video_path, posttrained_video_path)
235235

236236
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
237237
mesh = pipeline.mesh
@@ -440,11 +440,6 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps):
440440
# Reconstruct the model from its definition and parameters
441441
model = nnx.merge(state.graphdef, params, state.rest_of_state)
442442

443-
# Prepare inputs
444-
# latents = data["latents"].astype(config.weights_dtype)
445-
# encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
446-
# timesteps = data["timesteps"].astype("int64")
447-
448443
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
449444
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
450445

@@ -469,12 +464,11 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps):
469464
# Directly compute the loss without calculating gradients.
470465
# The model's state.params are used but not updated.
471466
bs = len(data["latents"])
472-
single_batch_size = min(8, config.global_batch_size_to_train_on)
467+
single_batch_size = min(config.eval_max_processed_batch_size, config.global_batch_size_to_train_on)
473468
losses = jnp.zeros(bs)
474469
for i in range(0, bs, single_batch_size):
475470
start = i
476471
end = min(i + single_batch_size, bs)
477-
jax.debug.print("Eval step processing samples {start} to {end}", start=start, end=end)
478472
latents= data["latents"][start:end, :].astype(config.weights_dtype)
479473
encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
480474
timesteps = data["timesteps"][start:end].astype("int64")
@@ -483,7 +477,6 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps):
483477

484478
# Structure the metrics for logging and aggregation
485479
metrics = {"scalar": {"learning/eval_loss": losses}}
486-
jax.debug.print("Eval step losses: {losses}", losses=losses)
487480

488481
# Return the computed metrics and the new RNG key for the next eval step
489482
return metrics, new_rng

0 commit comments

Comments (0)