Skip to content

Commit 47c8305

Browse files
committed
eval pipeline
1 parent 75d16a4 commit 47c8305

2 files changed

Lines changed: 78 additions & 29 deletions

File tree

src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,13 +55,17 @@ def float_feature_list(value):
5555
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
5656

5757

58-
def create_example(latent, hidden_states):
58+
def create_example(latent, hidden_states, timestep=None):
5959
latent = tf.io.serialize_tensor(latent)
6060
hidden_states = tf.io.serialize_tensor(hidden_states)
6161
feature = {
6262
"latents": bytes_feature(latent),
6363
"encoder_hidden_states": bytes_feature(hidden_states),
6464
}
65+
# Add timestep feature if it is provided
66+
if timestep is not None:
67+
feature["timesteps"] = int64_feature(timestep)
68+
6569
example = tf.train.Example(features=tf.train.Features(feature=feature))
6670
return example.SerializeToString()
6771

@@ -80,6 +84,11 @@ def generate_dataset(config):
8084
)
8185
shard_record_count = 0
8286

87+
# Define timesteps and bucket configuration
88+
timesteps_list = [125, 250, 375, 500, 625, 750, 875]
89+
bucket_size = 60
90+
num_samples_to_process = 420
91+
8392
# Load dataset
8493
metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
8594
with open(metadata_path, "r", newline="") as file:
@@ -102,7 +111,18 @@ def generate_dataset(config):
102111
# Save them as float32 because numpy cannot read bfloat16.
103112
latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
104113
prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
105-
writer.write(create_example(latent, prompt_embeds))
114+
115+
# Determine the timestep for the first 420 samples
116+
current_timestep = None
117+
if global_record_count < num_samples_to_process:
118+
bucket_index = global_record_count // bucket_size
119+
current_timestep = timesteps_list[bucket_index]
120+
else:
121+
print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}")
122+
return
123+
124+
# Write the example, including the timestep if applicable
125+
writer.write(create_example(latent, prompt_embeds, timestep=current_timestep))
106126
shard_record_count += 1
107127
global_record_count += 1
108128

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 56 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,11 @@ def get_data_shardings(self, mesh):
156156
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
157157
return data_sharding
158158

159+
def get_eval_data_shardings(self, mesh):
160+
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
161+
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": None}
162+
return data_sharding
163+
159164
def load_dataset(self, mesh, is_training=True):
160165
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
161166
# Image pre-training - txt2img 256px
@@ -170,25 +175,38 @@ def load_dataset(self, mesh, is_training=True):
170175
raise ValueError(
171176
"Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
172177
)
173-
174-
feature_description = {
178+
179+
feature_description_train = {
175180
"latents": tf.io.FixedLenFeature([], tf.string),
176181
"encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
177182
}
178183

179-
def prepare_sample(features):
184+
def prepare_sample_train(features):
180185
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
181186
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
182187
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}
188+
189+
feature_description_eval = {
190+
"latents": tf.io.FixedLenFeature([], tf.string),
191+
"encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
192+
"timesteps": tf.io.FixedLenFeature([], tf.int64),
193+
}
194+
195+
def prepare_sample_eval(features):
196+
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
197+
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
198+
timesteps = features["timesteps"]
199+
print(f"timesteps in prepare_sample_eval: {timesteps}")
200+
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
183201

184202
data_iterator = make_data_iterator(
185203
config,
186204
jax.process_index(),
187205
jax.process_count(),
188206
mesh,
189207
config.global_batch_size_to_load,
190-
feature_description=feature_description,
191-
prepare_sample_fn=prepare_sample,
208+
feature_description=feature_description_train if is_training else feature_description_eval,
209+
prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
192210
is_training=is_training,
193211
)
194212
return data_iterator
@@ -197,7 +215,7 @@ def start_training(self):
197215

198216
pipeline = self.load_checkpoint()
199217
# Generate a sample before training to compare against generated sample after training.
200-
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
218+
# pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
201219

202220
if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
203221
# save some memory.
@@ -215,8 +233,8 @@ def start_training(self):
215233
# Returns pipeline with trained transformer state
216234
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
217235

218-
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
219-
print_ssim(pretrained_video_path, posttrained_video_path)
236+
# posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
237+
# print_ssim(pretrained_video_path, posttrained_video_path)
220238

221239
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
222240
mesh = pipeline.mesh
@@ -231,6 +249,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
231249
state = jax.lax.with_sharding_constraint(state, state_spec)
232250
state_shardings = nnx.get_named_sharding(state, mesh)
233251
data_shardings = self.get_data_shardings(mesh)
252+
eval_data_shardings = self.get_eval_data_shardings(mesh)
234253

235254
writer = max_utils.initialize_summary_writer(self.config)
236255
writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +274,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
255274
)
256275
p_eval_step = jax.jit(
257276
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
258-
in_shardings=(state_shardings, data_shardings, None, None),
277+
in_shardings=(state_shardings, eval_data_shardings, None, None),
259278
out_shardings=(None, None),
260279
)
261280

262281
rng = jax.random.key(self.config.seed)
282+
rng, eval_rng_key = jax.random.split(rng)
263283
start_step = 0
264284
last_step_completion = datetime.datetime.now()
265285
local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -305,24 +325,36 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
305325
# Re-create the iterator each time you start evaluation to reset it
306326
# This assumes your data loading logic can be called to get a fresh iterator.
307327
eval_data_iterator = self.load_dataset(mesh, is_training=False)
308-
eval_rng = jax.random.key(self.config.seed + step)
309-
eval_metrics = []
328+
eval_rng = eval_rng_key
329+
eval_losses_by_timestep = {}
310330
# Loop indefinitely until the iterator is exhausted
311331
while True:
312332
try:
313333
with mesh:
314334
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
315335
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
316-
eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
336+
loss = metrics["scalar"]["learning/eval_loss"]
337+
timestep = int(eval_batch["timesteps"][0])
338+
if timestep not in eval_losses_by_timestep:
339+
eval_losses_by_timestep[timestep] = []
340+
eval_losses_by_timestep[timestep].append(loss)
317341
except StopIteration:
318342
# This block is executed when the iterator has no more data
319343
break
320344
# Check if any evaluation was actually performed
321-
if eval_metrics:
322-
eval_loss = jnp.mean(jnp.array(eval_metrics))
323-
max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
345+
if eval_losses_by_timestep:
346+
mean_per_timestep = []
347+
max_logging.log(f"Step {step}, calculating mean loss per timestep...")
348+
for timestep, losses in sorted(eval_losses_by_timestep.items()):
349+
losses = jnp.array(losses)
350+
losses = losses[: min(60, len(losses))]
351+
mean_loss = jnp.mean(losses)
352+
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
353+
mean_per_timestep.append(mean_loss)
354+
final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
355+
max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
324356
if writer:
325-
writer.add_scalar("learning/eval_loss", eval_loss, step)
357+
writer.add_scalar("learning/eval_loss", final_eval_loss, step)
326358
else:
327359
max_logging.log(f"Step {step}, evaluation dataset was empty.")
328360
example_batch = next_batch_future.result()
@@ -394,12 +426,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
394426
"""
395427
Computes the evaluation loss for a single batch without updating model weights.
396428
"""
397-
_, new_rng, timestep_rng = jax.random.split(rng, num=3)
429+
_, new_rng = jax.random.split(rng, num=2)
398430

399431
# This ensures the batch size is consistent, though it might be redundant
400432
# if the evaluation dataloader is already configured correctly.
401433
for k, v in data.items():
402-
data[k] = v[: config.global_batch_size_to_train_on, :]
434+
if k != "timesteps":
435+
data[k] = v[: config.global_batch_size_to_train_on, :]
436+
else:
437+
data[k] = v[: config.global_batch_size_to_train_on]
403438

404439
# The loss function logic is identical to training. We are evaluating the model's
405440
# ability to perform its core training objective (e.g., denoising).
@@ -410,15 +445,8 @@ def loss_fn(params):
410445
# Prepare inputs
411446
latents = data["latents"].astype(config.weights_dtype)
412447
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
413-
bsz = latents.shape[0]
448+
timesteps = data["timesteps"].astype("int64")
414449

415-
# Sample random timesteps and noise, just as in a training step
416-
timesteps = jax.random.randint(
417-
timestep_rng,
418-
(bsz,),
419-
0,
420-
scheduler.config.num_train_timesteps,
421-
)
422450
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
423451
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
424452

@@ -427,6 +455,7 @@ def loss_fn(params):
427455
hidden_states=noisy_latents,
428456
timestep=timesteps,
429457
encoder_hidden_states=encoder_hidden_states,
458+
deterministic=True,
430459
)
431460

432461
# Calculate the loss against the target
@@ -447,4 +476,4 @@ def loss_fn(params):
447476
metrics = {"scalar": {"learning/eval_loss": loss}}
448477

449478
# Return the computed metrics and the new RNG key for the next eval step
450-
return metrics, new_rng
479+
return metrics, new_rng,

0 commit comments

Comments (0)