
Commit 2c9c73d

version 3

1 parent 82502da

1 file changed: src/maxdiffusion/trainers/wan_trainer.py (1 addition, 9 deletions)
@@ -158,8 +158,7 @@ def get_data_shardings(self, mesh):
 
   def get_eval_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
-    timesteps_sharding = jax.sharding.NamedSharding(mesh, P('data'))
-    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding}
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding}
     return data_sharding
 
   def load_dataset(self, mesh, is_training=True):
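The deleted lines gave "timesteps" its own NamedSharding over the 'data' mesh axis; after this change every entry in the eval batch reuses the single sharding built from self.config.data_sharding. A minimal standalone sketch of why one sharding can serve arrays of different ranks (the mesh layout, axis name, and shapes below are illustrative assumptions, not values from this repo):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh whose only axis is named "data", matching P("data") above.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))  # shard axis 0 over "data"

n = jax.device_count()  # batch size divisible by the mesh axis size
batch = {
    "latents": np.zeros((n, 16, 32), np.float32),
    "encoder_hidden_states": np.zeros((n, 77, 512), np.float32),
    "timesteps": np.zeros((n,), np.float32),
}
# P("data") only constrains the leading (batch) dimension, so the same
# sharding is valid for the 1-D timesteps and the 3-D latents alike.
sharded = jax.device_put(batch, {k: batch_sharding for k in batch})
for name, arr in sharded.items():
  print(name, arr.sharding)
```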
@@ -194,7 +193,6 @@ def prepare_sample_eval(features):
       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
       timesteps = features["timesteps"]
-      tf.print("timesteps in prepare_sample_eval:", timesteps)
       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
 
     data_iterator = make_data_iterator(
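The removed tf.print was only a debugging aid; the function's real job is deserializing tensors that were stored as raw bytes in the records. A self-contained sketch of the tf.io.serialize_tensor / tf.io.parse_tensor round trip it relies on (shapes are illustrative):

```python
import tensorflow as tf

# Tensors are written to the record as serialized byte strings...
latents = tf.random.normal((4, 16, 32))
serialized = tf.io.serialize_tensor(latents)

# ...and recovered with an explicit out_type, as prepare_sample_eval does.
restored = tf.io.parse_tensor(serialized, out_type=tf.float32)
tf.debugging.assert_near(latents, restored)
```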
@@ -330,9 +328,6 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       try:
         with mesh:
           eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-          eval_batch["timesteps"] = jax.device_put(
-              eval_batch["timesteps"], eval_data_shardings["timesteps"]
-          )
           metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
           losses = metrics["scalar"]["learning/eval_loss"]
           timesteps = eval_batch["timesteps"]
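With "timesteps" now sharded like every other key in eval_data_shardings, the per-key jax.device_put deleted here becomes redundant: re-putting an array onto a sharding it already has does not trigger a fresh cross-device transfer. A hypothetical sketch of that no-op (mesh layout and shapes are assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

timesteps = jax.device_put(jnp.zeros((jax.device_count(),)), sharding)
again = jax.device_put(timesteps, sharding)  # already placed: no resharding copy
assert again.sharding == timesteps.sharding
```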
@@ -432,9 +427,6 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
   """
   _, new_rng = jax.random.split(rng, num=2)
 
-  # This ensures the batch size is consistent, though it might be redundant
-  # if the evaluation dataloader is already configured correctly.
-  jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
   def loss_fn(params):
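The deleted jax.debug.print was the jit-safe way to peek at "timesteps" inside the compiled eval step, since an ordinary Python print only fires once at trace time. A standalone sketch of that pattern (the clip bounds and sample values are illustrative, not from this trainer):

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_timesteps(timesteps):
  # jax.debug.print runs at execution time, even under jit.
  jax.debug.print("timesteps before clip: {x}", x=timesteps)
  return jnp.clip(timesteps, 0, 999)

clip_timesteps(jnp.array([3.0, 1500.0, 42.0]))
```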
