Commit 3f4d3d4

change 1
1 parent edcd3c2 commit 3f4d3d4

1 file changed: src/maxdiffusion/trainers/wan_trainer.py (10 additions & 3 deletions)
```diff
@@ -158,7 +158,8 @@ def get_data_shardings(self, mesh):
 
   def get_eval_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
-    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": None}
+    timesteps_sharding = jax.sharding.NamedSharding(mesh, P('data'))
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding}
     return data_sharding
 
   def load_dataset(self, mesh, is_training=True):
```
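This hunk replaces the `None` entry for `timesteps` with an explicit `jax.sharding.NamedSharding` built from `P('data')`, so the 1-D timesteps vector gets a declared layout along the mesh's `data` axis like the other batch entries. A minimal sketch of what such a sharding does, with an illustrative 1-D mesh (the mesh construction below is an assumption, not code from this repo):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative only: a 1-D mesh whose single axis is named "data",
# matching the axis name used by P('data') in the hunk above.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
timesteps_sharding = NamedSharding(mesh, P("data"))

# A 1-D timesteps array is split along its only axis, one shard per
# device on the "data" mesh axis.
timesteps = jax.device_put(np.arange(8, dtype=np.int32), timesteps_sharding)
print(timesteps.sharding)
```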
```diff
@@ -196,7 +197,7 @@ def prepare_sample_eval(features):
       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
       timesteps = features["timesteps"]
-      print(f"timesteps in prepare_sample_eval: {timesteps}")
+      tf.print("timesteps in prepare_sample_eval:", timesteps)
       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
 
     data_iterator = make_data_iterator(
```
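Swapping `print` for `tf.print` matters because `prepare_sample_eval` runs inside a `tf.data` pipeline: a Python `print` fires once at graph-tracing time and shows only a symbolic tensor, while `tf.print` is embedded in the graph and prints concrete values for every element. A standalone sketch of the difference (toy dataset and a hypothetical map function, not the repo's):

```python
import tensorflow as tf

def prepare(features):
    timesteps = features["timesteps"]
    # print(timesteps) here would run once, at trace time, and show a
    # symbolic Tensor rather than a value.
    tf.print("timesteps in prepare:", timesteps)  # runs per element at runtime
    return features

ds = tf.data.Dataset.from_tensor_slices({"timesteps": tf.range(4)}).map(prepare)
for _ in ds:  # iterating executes the graph, triggering tf.print each time
    pass
```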
```diff
@@ -332,9 +333,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       try:
         with mesh:
           eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+          eval_batch["timesteps"] = jax.device_put(
+              eval_batch["timesteps"], eval_data_shardings["timesteps"]
+          )
           metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
           loss = metrics["scalar"]["learning/eval_loss"]
           timestep = int(eval_batch["timesteps"][0])
+          jax.debug.print("timesteps in eval_step: {x}", x=timestep)
           if timestep not in eval_losses_by_timestep:
             eval_losses_by_timestep[timestep] = []
           eval_losses_by_timestep[timestep].append(loss)
```
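The `jax.device_put` call reshards the freshly loaded `timesteps` onto the layout declared in `get_eval_data_shardings` before the batch reaches `p_eval_step`. A minimal sketch of that pattern, with an illustrative mesh and a hypothetical stand-in for the compiled step (names and shapes are assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
eval_data_shardings = {"timesteps": NamedSharding(mesh, P("data"))}

# Place the host array on devices with the declared sharding before the
# compiled step sees it, as the loop above does for eval_batch["timesteps"].
eval_batch = {"timesteps": np.full((8,), 500, dtype=np.int32)}
eval_batch["timesteps"] = jax.device_put(
    eval_batch["timesteps"], eval_data_shardings["timesteps"]
)

@jax.jit
def p_eval_step_like(timesteps):  # hypothetical stand-in for p_eval_step
    return jnp.mean(timesteps.astype(jnp.float32))

print(p_eval_step_like(eval_batch["timesteps"]))
```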
```diff
@@ -349,7 +354,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         losses = jnp.array(losses)
         losses = losses[: min(60, len(losses))]
         mean_loss = jnp.mean(losses)
-        max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+        max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
         mean_per_timestep.append(mean_loss)
       final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
       max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
```
```diff
@@ -430,11 +435,13 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
 
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
+  jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
   for k, v in data.items():
     if k != "timesteps":
       data[k] = v[: config.global_batch_size_to_train_on, :]
     else:
       data[k] = v[: config.global_batch_size_to_train_on]
+  jax.debug.print("timesteps after clip: {x}", x=data["timesteps"])
 
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
```
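The two debug lines use `jax.debug.print` rather than `print` because `eval_step` is compiled (via `p_eval_step`): inside jitted code, `print` fires once at trace time with tracer objects, while `jax.debug.print` stages a host callback that reports runtime values on every call. A self-contained sketch of that behavior, simplified to a single 1-D entry and a constant batch size:

```python
import jax
import jax.numpy as jnp

BATCH = 2  # stand-in for config.global_batch_size_to_train_on

@jax.jit
def clip_like_eval_step(data):  # hypothetical, simplified eval_step
    jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
    data = {k: v[:BATCH] for k, v in data.items()}  # clip to the batch size
    jax.debug.print("timesteps after clip: {x}", x=data["timesteps"])
    return data

clip_like_eval_step({"timesteps": jnp.arange(4)})  # prints both runtime values
```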
