Skip to content

Commit 38283b5

Browse files
authored
add WAN evaluation pipeline (#232)
* add evaluation into WAN pipeline * refactor tfrecord function and change eval loss name * fix lint * use max_logger
1 parent c44f0e5 commit 38283b5

4 files changed

Lines changed: 165 additions & 70 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,6 @@ use_qwix_quantization: False # Whether to use qwix for quantization. If set to T
291291
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
292292
quantization_calibration_method: "absmax"
293293

294+
# Evaluate the model every `eval_every` training steps. -1 disables evaluation.
295+
eval_every: -1
296+
eval_data_dir: ""

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tensorflow.experimental.numpy as tnp
2020
from datasets import load_dataset, load_from_disk
2121
import jax
22-
from maxdiffusion import multihost_dataloading
22+
from maxdiffusion import multihost_dataloading, max_logging
2323

2424
AUTOTUNE = tf.data.AUTOTUNE
2525

@@ -78,92 +78,91 @@ def make_tf_iterator(
7878
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7979
return train_iter
8080

81-
82-
def make_cached_tfrecord_iterator(
83-
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
84-
):
85-
"""
86-
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
87-
latents, input_ids, prompt_embeds, and text_embeds.
88-
"""
89-
90-
def _parse_tfrecord_fn(example):
91-
return tf.io.parse_single_example(example, feature_description)
92-
93-
# This pipeline reads the sharded files and applies the parsing and preparation.
94-
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
95-
96-
train_ds = (
97-
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
98-
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
99-
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
100-
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
101-
.shuffle(global_batch_size * 10)
102-
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
103-
.repeat(-1)
104-
.prefetch(AUTOTUNE)
105-
)
106-
107-
# This wraps the tf.data.Dataset for use in the multi-host JAX environment.
108-
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
109-
return train_iter
110-
111-
11281
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
113-
def make_tfrecord_iterator(
114-
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
82+
def _make_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
    feature_description_fn,
    prepare_sample_fn,
    dataset_path,
    is_training: bool,
):
  """Builds a sharded TFRecord iterator for training or evaluation.

  Set load_tfrecord_cached to True in config to use a pre-processed tfrecord
  dataset (pedagogical_examples/dataset_tf_cache_to_tfrecord.py converts a tf
  preprocessed dataset to tfrecord). The dataset cache in the github runner
  test doesn't contain all the features since it is shared, so the default
  internal parsing logic is used in that case.

  Args:
    config: job config; reads dataset_save_location,
      cache_latents_text_encoder_outputs and the optional load_tfrecord_cached key.
    dataloading_host_index: index of this host among the data-loading hosts.
    dataloading_host_count: total number of data-loading hosts.
    mesh: device mesh passed to MultiHostDataLoadIterator.
    global_batch_size: batch size summed across all hosts.
    feature_description_fn: externally provided tf.io feature description,
      used only on the cached-tfrecord path.
    prepare_sample_fn: externally provided sample-preparation map fn,
      used only on the cached-tfrecord path.
    dataset_path: directory (local or gs://) whose files are read as TFRecords.
    is_training: True -> shuffled, drop-remainder, infinitely repeated batches;
      False -> a single padded pass over the evaluation data.

  Returns:
    A multihost_dataloading.MultiHostDataLoadIterator over the pipeline.
  """
  # Checks that the dataset path is valid. In case of gcs, the existence of
  # the dir is not checked.
  is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

  # Determine whether to use the "cached" dataset, which requires externally
  # provided parsing functions, or the default one with its internal parsing logic.
  use_cached_format = (
      config.cache_latents_text_encoder_outputs
      and is_dataset_dir_valid
      and "load_tfrecord_cached" in config.get_keys()
      and config.load_tfrecord_cached
  )

  # Default (non-cached) schema: serialized VAE moments + CLIP embeddings.
  default_feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
  }
  used_feature_description = feature_description_fn if use_cached_format else default_feature_description

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, used_feature_description)

  def _default_prepare_sample(features):
    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
    return {"pixel_values": moments, "input_ids": clip_embeddings}

  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)

  # --- PADDING LOGIC FOR EVALUATION ---
  # Pad the eval set so its size divides global_batch_size; the padding
  # samples are duplicates taken from the beginning of the dataset.
  # NOTE: counting requires one full pass over the raw records up front.
  if not is_training:
    num_eval_samples = sum(1 for _ in ds)
    remainder = num_eval_samples % global_batch_size
    if remainder != 0:
      num_to_pad = global_batch_size - remainder
      # Create a dataset of padding samples from the beginning and append it.
      padding_ds = ds.take(num_to_pad)
      ds = ds.concatenate(padding_ds)
      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")

  used_prepare_sample = prepare_sample_fn if use_cached_format else _default_prepare_sample
  ds = (
      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
  )
  if is_training:
    ds = (
        ds.shuffle(global_batch_size * 10)
        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
        .repeat(-1)
        .prefetch(AUTOTUNE)
    )
  else:
    # For evaluation: one pass, keep the final (already padded) batch.
    ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE)

  # Named data_iter rather than `iter` to avoid shadowing the builtin.
  data_iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
  return data_iter
157+
158+
def make_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
    feature_description,
    prepare_sample_fn,
    is_training=True,
):
  """Iterator for TFRecord format. For Laion dataset,
  check out preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py

  Currently evaluation is only supported for the tfrecord format; the
  is_training flag selects between config.train_data_dir and
  config.eval_data_dir. It defaults to True so pre-existing training callers
  that do not pass it keep their old behavior.
  TODO: refactor to support evaluation on all dataset formats.
  """
  dataset_path = config.train_data_dir if is_training else config.eval_data_dir
  return _make_tfrecord_iterator(
      config,
      dataloading_host_index,
      dataloading_host_count,
      mesh,
      global_batch_size,
      feature_description,
      prepare_sample_fn,
      dataset_path,
      is_training,
  )

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def make_data_iterator(
5252
image_transforms_fn=None,
5353
feature_description=None,
5454
prepare_sample_fn=None,
55+
is_training=True,
5556
):
5657
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
5758

@@ -106,6 +107,7 @@ def make_data_iterator(
106107
global_batch_size,
107108
feature_description,
108109
prepare_sample_fn,
110+
is_training
109111
)
110112
else:
111113
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_data_shardings(self, mesh):
101101
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
102102
return data_sharding
103103

104-
def load_dataset(self, mesh):
104+
def load_dataset(self, mesh, is_training=True):
105105
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
106106
# Image pre-training - txt2img 256px
107107
# Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
@@ -134,6 +134,7 @@ def prepare_sample(features):
134134
config.global_batch_size_to_load,
135135
feature_description=feature_description,
136136
prepare_sample_fn=prepare_sample,
137+
is_training=is_training,
137138
)
138139
return data_iterator
139140

@@ -145,20 +146,20 @@ def start_training(self):
145146
# Generate a sample before training to compare against generated sample after training.
146147
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
147148
mesh = pipeline.mesh
148-
data_iterator = self.load_dataset(mesh)
149+
train_data_iterator = self.load_dataset(mesh, is_training=True)
149150

150151
# Load FlowMatch scheduler
151152
scheduler, scheduler_state = self.create_scheduler()
152153
pipeline.scheduler = scheduler
153154
pipeline.scheduler_state = scheduler_state
154155
optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
155156
# Returns pipeline with trained transformer state
156-
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
157+
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
157158

158159
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
159160
print_ssim(pretrained_video_path, posttrained_video_path)
160161

161-
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
162+
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
162163
mesh = pipeline.mesh
163164
graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
164165

@@ -193,6 +194,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
193194
out_shardings=(state_shardings, None, None, None),
194195
donate_argnums=(0,),
195196
)
197+
p_eval_step = jax.jit(
198+
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
199+
in_shardings=(state_shardings, data_shardings, None, None),
200+
out_shardings=(None, None),
201+
)
202+
196203
rng = jax.random.key(self.config.seed)
197204
start_step = 0
198205
last_step_completion = datetime.datetime.now()
@@ -209,13 +216,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
209216
per_device_tflops = self.calculate_tflops(pipeline)
210217

211218
scheduler_state = pipeline.scheduler_state
212-
example_batch = load_next_batch(data_iterator, None, self.config)
219+
example_batch = load_next_batch(train_data_iterator, None, self.config)
213220
with ThreadPoolExecutor(max_workers=1) as executor:
214221
for step in np.arange(start_step, self.config.max_train_steps):
215222
if self.config.enable_profiler and step == first_profiling_step:
216223
max_utils.activate_profiler(self.config)
217224
start_step_time = datetime.datetime.now()
218-
next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
225+
next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
219226
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
220227
self.config.logical_axis_rules
221228
):
@@ -231,6 +238,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
231238
)
232239
if self.config.write_metrics:
233240
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
241+
242+
if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
243+
# Re-create the iterator each time you start evaluation to reset it
244+
# This assumes your data loading logic can be called to get a fresh iterator.
245+
eval_data_iterator = self.load_dataset(mesh, is_training=False)
246+
eval_rng = jax.random.key(self.config.seed + step)
247+
eval_metrics = []
248+
# Loop indefinitely until the iterator is exhausted
249+
while True:
250+
try:
251+
with mesh:
252+
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
253+
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
254+
eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
255+
except StopIteration:
256+
# This block is executed when the iterator has no more data
257+
break
258+
# Check if any evaluation was actually performed
259+
if eval_metrics:
260+
eval_loss = jnp.mean(jnp.array(eval_metrics))
261+
max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
262+
if writer:
263+
writer.add_scalar("learning/eval_loss", eval_loss, step)
264+
else:
265+
max_logging.log(f"Step {step}, evaluation dataset was empty.")
234266
example_batch = next_batch_future.result()
235267

236268
_metrics_queue.put(None)
@@ -286,3 +318,62 @@ def loss_fn(params):
286318
new_state = state.apply_gradients(grads=grads)
287319
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
288320
return new_state, scheduler_state, metrics, new_rng
321+
322+
def eval_step(state, data, rng, scheduler_state, scheduler, config):
  """Computes the evaluation loss for a single batch without updating weights.

  Args:
    state: train state holding graphdef/params/rest_of_state (from nnx.split).
    data: batch dict with "latents" and "encoder_hidden_states" arrays;
      assumes each value has at least 2 dims for the batch slice below —
      TODO confirm against the data pipeline.
    rng: PRNG key for this evaluation step.
    scheduler_state: state of the FlowMatch scheduler.
    scheduler: the FlowMatch scheduler instance.
    config: job config (weights_dtype, global_batch_size_to_train_on).

  Returns:
    (metrics, new_rng): metrics dict with "learning/eval_loss", and a fresh
    key to carry into the next eval step.
  """
  # Derive independent keys: one to return, one for noise, one for timesteps.
  # The returned key is never consumed here, avoiding JAX key reuse (the
  # original discarded one split output and both sampled noise with and
  # returned the same key).
  new_rng, noise_rng, timestep_rng = jax.random.split(rng, num=3)

  # Trim to the per-step batch size. Build a new dict rather than mutating
  # the caller's batch in place. This may be redundant if the evaluation
  # dataloader is already configured correctly.
  data = {k: v[: config.global_batch_size_to_train_on, :] for k, v in data.items()}

  # The loss logic is identical to training: we evaluate the model's core
  # denoising objective, just without a gradient/update step.
  def loss_fn(params):
    # Reconstruct the model from its definition and parameters.
    model = nnx.merge(state.graphdef, params, state.rest_of_state)

    latents = data["latents"].astype(config.weights_dtype)
    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
    bsz = latents.shape[0]

    # Sample random timesteps and noise, just as in a training step.
    timesteps = jax.random.randint(
        timestep_rng,
        (bsz,),
        0,
        scheduler.config.num_train_timesteps,
    )
    noise = jax.random.normal(key=noise_rng, shape=latents.shape, dtype=latents.dtype)
    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

    model_pred = model(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=encoder_hidden_states,
    )

    # Weighted MSE against the scheduler's training target.
    training_target = scheduler.training_target(latents, noise, timesteps)
    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
    return jnp.mean(((training_target - model_pred) ** 2) * training_weight)

  # Key difference from train_step: forward pass only — no gradients, and
  # state.params are read but never updated.
  loss = loss_fn(state.params)

  metrics = {"scalar": {"learning/eval_loss": loss}}
  return metrics, new_rng

0 commit comments

Comments
 (0)