Skip to content

Commit 32c043c

Browse files
committed
add evaluation into WAN pipeline
1 parent fc46fcc commit 32c043c

4 files changed

Lines changed: 190 additions & 37 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,4 +287,6 @@ quantization: ''
287287
# Shard the range finding operation for quantization. By default this is set to number of slices.
288288
quantization_local_shard_count: -1
289289
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
290-
290+
# Evaluate the model every eval_every steps. -1 means don't eval.
291+
eval_every: -1
292+
eval_data_dir: ""

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,8 @@ def make_tf_iterator(
7878
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7979
return train_iter
8080

81-
8281
def make_cached_tfrecord_iterator(
83-
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
82+
dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training: bool
8483
):
8584
"""
8685
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
@@ -91,37 +90,58 @@ def _parse_tfrecord_fn(example):
9190
return tf.io.parse_single_example(example, feature_description)
9291

9392
# This pipeline reads the sharded files and applies the parsing and preparation.
94-
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
95-
96-
train_ds = (
97-
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
98-
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
99-
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
100-
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
101-
.shuffle(global_batch_size * 10)
93+
filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
94+
ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
95+
96+
# --- PADDING LOGIC FOR EVALUATION ---
97+
if not is_training:
98+
num_eval_samples = 0
99+
for _ in ds:
100+
num_eval_samples += 1
101+
102+
remainder = num_eval_samples % global_batch_size
103+
if remainder != 0:
104+
num_to_pad = global_batch_size - remainder
105+
# Create a dataset of padding samples from the beginning
106+
padding_ds = ds.take(num_to_pad)
107+
# Add the padding samples to the end
108+
ds = ds.concatenate(padding_ds)
109+
print(f"Padded evaluation dataset with {num_to_pad} samples.")
110+
111+
ds = (
112+
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
113+
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
114+
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
115+
)
116+
if is_training:
117+
ds = (
118+
ds.shuffle(global_batch_size * 10)
102119
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
103120
.repeat(-1)
104121
.prefetch(AUTOTUNE)
105-
)
122+
)
123+
# For Evaluation
124+
else:
125+
ds = (
126+
ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
127+
.prefetch(AUTOTUNE)
128+
)
106129

107130
# This wraps the tf.data.Dataset for use in the multi-host JAX environment.
108-
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
109-
return train_iter
131+
iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
132+
return iter
110133

111134

112135
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
113-
def make_tfrecord_iterator(
114-
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
136+
def _make_tfrecord_iterator(
137+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training: bool
115138
):
116-
"""Iterator for TFRecord format. For Laion dataset,
117-
check out preparation script
118-
maxdiffusion/pedagogical_examples/to_tfrecords.py
119-
"""
120139
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
121140
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
122141
# Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
123142

124143
# Checks that the dataset path is valid. For GCS paths, the existence of the directory is not checked.
144+
# If is_training is True, loads the training dataset; if False, loads the evaluation dataset.
125145
is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
126146

127147
if (
@@ -131,13 +151,14 @@ def make_tfrecord_iterator(
131151
and config.load_tfrecord_cached
132152
):
133153
return make_cached_tfrecord_iterator(
134-
config,
135154
dataloading_host_index,
136155
dataloading_host_count,
137156
mesh,
138157
global_batch_size,
139158
feature_description,
140159
prepare_sample_fn,
160+
dataset_path,
161+
is_training
141162
)
142163

143164
feature_description = {
@@ -153,17 +174,54 @@ def prepare_sample(features):
153174
clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
154175
return {"pixel_values": moments, "input_ids": clip_embeddings}
155176

156-
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
157-
train_ds = (
158-
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
159-
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
160-
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
161-
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
162-
.shuffle(global_batch_size * 10)
177+
filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
178+
179+
ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
180+
# --- PADDING LOGIC FOR EVALUATION ---
181+
if not is_training:
182+
num_eval_samples = 0
183+
for _ in ds:
184+
num_eval_samples += 1
185+
186+
remainder = num_eval_samples % global_batch_size
187+
if remainder != 0:
188+
num_to_pad = global_batch_size - remainder
189+
# Create a dataset of padding samples from the beginning
190+
padding_ds = ds.take(num_to_pad)
191+
# Add the padding samples to the end
192+
ds = ds.concatenate(padding_ds)
193+
print(f"Padded evaluation dataset with {num_to_pad} samples.")
194+
195+
ds = (
196+
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
197+
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
198+
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
199+
)
200+
if is_training:
201+
ds = (
202+
ds.shuffle(global_batch_size * 10)
163203
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
164204
.repeat(-1)
165205
.prefetch(AUTOTUNE)
166-
)
206+
)
207+
# For Evaluation
208+
else:
209+
ds = (
210+
ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
211+
.prefetch(AUTOTUNE)
212+
)
167213

168-
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
169-
return train_iter
214+
iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
215+
return iter
216+
217+
def make_tfrecord_iterator(
218+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
219+
):
220+
"""Iterator for TFRecord format. For Laion dataset,
221+
check out preparation script
222+
maxdiffusion/pedagogical_examples/to_tfrecords.py
223+
"""
224+
# Currently only supports evaluation for the tfrecord format. To avoid changing existing behavior, select the dataset path based on whether this is the training dataset.
225+
# TODO: refactor to support evaluation on all dataset format.
226+
dataset_path = config.train_data_dir if is_training else config.eval_data_dir
227+
return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def make_data_iterator(
5252
image_transforms_fn=None,
5353
feature_description=None,
5454
prepare_sample_fn=None,
55+
is_training=True,
5556
):
5657
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
5758

@@ -106,6 +107,7 @@ def make_data_iterator(
106107
global_batch_size,
107108
feature_description,
108109
prepare_sample_fn,
110+
is_training
109111
)
110112
else:
111113
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_data_shardings(self, mesh):
101101
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
102102
return data_sharding
103103

104-
def load_dataset(self, mesh):
104+
def load_dataset(self, mesh, is_training=True):
105105
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
106106
# Image pre-training - txt2img 256px
107107
# Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
@@ -134,6 +134,7 @@ def prepare_sample(features):
134134
config.global_batch_size_to_load,
135135
feature_description=feature_description,
136136
prepare_sample_fn=prepare_sample,
137+
is_training=is_training,
137138
)
138139
return data_iterator
139140

@@ -145,20 +146,20 @@ def start_training(self):
145146
# Generate a sample before training to compare against generated sample after training.
146147
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
147148
mesh = pipeline.mesh
148-
data_iterator = self.load_dataset(mesh)
149+
train_data_iterator = self.load_dataset(mesh, is_training=True)
149150

150151
# Load FlowMatch scheduler
151152
scheduler, scheduler_state = self.create_scheduler()
152153
pipeline.scheduler = scheduler
153154
pipeline.scheduler_state = scheduler_state
154155
optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
155156
# Returns pipeline with trained transformer state
156-
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
157+
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
157158

158159
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
159160
print_ssim(pretrained_video_path, posttrained_video_path)
160161

161-
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
162+
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
162163
mesh = pipeline.mesh
163164
graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
164165

@@ -193,6 +194,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
193194
out_shardings=(state_shardings, None, None, None),
194195
donate_argnums=(0,),
195196
)
197+
p_eval_step = jax.jit(
198+
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
199+
in_shardings=(state_shardings, data_shardings, None, None),
200+
out_shardings=(None, None),
201+
)
202+
196203
rng = jax.random.key(self.config.seed)
197204
start_step = 0
198205
last_step_completion = datetime.datetime.now()
@@ -209,13 +216,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
209216
per_device_tflops = self.calculate_tflops(pipeline)
210217

211218
scheduler_state = pipeline.scheduler_state
212-
example_batch = load_next_batch(data_iterator, None, self.config)
219+
example_batch = load_next_batch(train_data_iterator, None, self.config)
213220
with ThreadPoolExecutor(max_workers=1) as executor:
214221
for step in np.arange(start_step, self.config.max_train_steps):
215222
if self.config.enable_profiler and step == first_profiling_step:
216223
max_utils.activate_profiler(self.config)
217224
start_step_time = datetime.datetime.now()
218-
next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
225+
next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
219226
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
220227
self.config.logical_axis_rules
221228
):
@@ -231,6 +238,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
231238
)
232239
if self.config.write_metrics:
233240
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
241+
242+
if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
243+
# Re-create the iterator each time you start evaluation to reset it
244+
# This assumes your data loading logic can be called to get a fresh iterator.
245+
eval_data_iterator = self.load_dataset(mesh, is_training=False)
246+
eval_rng = jax.random.key(self.config.seed + step)
247+
eval_metrics = []
248+
# Loop indefinitely until the iterator is exhausted
249+
while True:
250+
try:
251+
with mesh:
252+
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
253+
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
254+
eval_metrics.append(metrics["scalar"]["eval/loss"])
255+
except StopIteration:
256+
# This block is executed when the iterator has no more data
257+
break
258+
# Check if any evaluation was actually performed
259+
if eval_metrics:
260+
eval_loss = jnp.mean(jnp.array(eval_metrics))
261+
max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
262+
if writer:
263+
writer.add_scalar("eval/loss", eval_loss, step)
264+
else:
265+
max_logging.log(f"Step {step}, evaluation dataset was empty.")
234266
example_batch = next_batch_future.result()
235267

236268
_metrics_queue.put(None)
@@ -286,3 +318,62 @@ def loss_fn(params):
286318
new_state = state.apply_gradients(grads=grads)
287319
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
288320
return new_state, scheduler_state, metrics, new_rng
321+
322+
def eval_step(state, data, rng, scheduler_state, scheduler, config):
323+
"""
324+
Computes the evaluation loss for a single batch without updating model weights.
325+
"""
326+
_, new_rng, timestep_rng = jax.random.split(rng, num=3)
327+
328+
# This ensures the batch size is consistent, though it might be redundant
329+
# if the evaluation dataloader is already configured correctly.
330+
for k, v in data.items():
331+
data[k] = v[: config.global_batch_size_to_train_on, :]
332+
333+
# The loss function logic is identical to training. We are evaluating the model's
334+
# ability to perform its core training objective (e.g., denoising).
335+
def loss_fn(params):
336+
# Reconstruct the model from its definition and parameters
337+
model = nnx.merge(state.graphdef, params, state.rest_of_state)
338+
339+
# Prepare inputs
340+
latents = data["latents"].astype(config.weights_dtype)
341+
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
342+
bsz = latents.shape[0]
343+
344+
# Sample random timesteps and noise, just as in a training step
345+
timesteps = jax.random.randint(
346+
timestep_rng,
347+
(bsz,),
348+
0,
349+
scheduler.config.num_train_timesteps,
350+
)
351+
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
352+
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
353+
354+
# Get the model's prediction
355+
model_pred = model(
356+
hidden_states=noisy_latents,
357+
timestep=timesteps,
358+
encoder_hidden_states=encoder_hidden_states,
359+
)
360+
361+
# Calculate the loss against the target
362+
training_target = scheduler.training_target(latents, noise, timesteps)
363+
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
364+
loss = (training_target - model_pred) ** 2
365+
loss = loss * training_weight
366+
loss = jnp.mean(loss)
367+
368+
return loss
369+
370+
# --- Key Difference from train_step ---
371+
# Directly compute the loss without calculating gradients.
372+
# The model's state.params are used but not updated.
373+
loss = loss_fn(state.params)
374+
375+
# Structure the metrics for logging and aggregation
376+
metrics = {"scalar": {"eval/loss": loss}}
377+
378+
# Return the computed metrics and the new RNG key for the next eval step
379+
return metrics, new_rng

0 commit comments

Comments
 (0)