
Commit a7711f3

Merge branch 'main' into remove_conv_sharding
2 parents: 8040f29 + 38283b5

4 files changed

Lines changed: 165 additions & 70 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
@@ -293,3 +293,6 @@ use_qwix_quantization: False # Whether to use qwix for quantization. If set to T
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 quantization_calibration_method: "absmax"
 
+# Eval model on per eval_every steps. -1 means don't eval.
+eval_every: -1
+eval_data_dir: ""
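
These keys are consumed by the trainer in src/maxdiffusion/trainers/wan_trainer.py (see that diff below): evaluation runs every eval_every steps and is disabled when the value is -1. A minimal sketch of that gating logic, assuming a parsed config object that exposes the YAML keys as attributes (the helper name is illustrative, not part of the commit):

    # Sketch only: mirrors the check added in wan_trainer.py.
    def should_run_eval(config, step: int) -> bool:
      if config.eval_every <= 0:
        return False  # eval_every: -1 disables evaluation
      return (step + 1) % config.eval_every == 0  # evaluate every eval_every steps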

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 63 additions & 64 deletions
@@ -19,7 +19,7 @@
 import tensorflow.experimental.numpy as tnp
 from datasets import load_dataset, load_from_disk
 import jax
-from maxdiffusion import multihost_dataloading
+from maxdiffusion import multihost_dataloading, max_logging
 
 AUTOTUNE = tf.data.AUTOTUNE
 
@@ -78,92 +78,91 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
-
-def make_cached_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
-):
-  """
-  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
-  latents, input_ids, prompt_embeds, and text_embeds.
-  """
-
-  def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
-
-  # This pipeline reads the sharded files and applies the parsing and preparation.
-  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
-
-  train_ds = (
-      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
-      .shuffle(global_batch_size * 10)
-      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .repeat(-1)
-      .prefetch(AUTOTUNE)
-  )
-
-  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
-  return train_iter
-
-
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
-def make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
+def _make_tfrecord_iterator(
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
 ):
-  """Iterator for TFRecord format. For Laion dataset,
-  check out preparation script
-  maxdiffusion/pedagogical_examples/to_tfrecords.py
-  """
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
+  # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.
 
   # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
 
-  if (
-      config.cache_latents_text_encoder_outputs
-      and is_dataset_dir_valid
-      and "load_tfrecord_cached" in config.get_keys()
-      and config.load_tfrecord_cached
-  ):
-    return make_cached_tfrecord_iterator(
-        config,
-        dataloading_host_index,
-        dataloading_host_count,
-        mesh,
-        global_batch_size,
-        feature_description,
-        prepare_sample_fn,
-    )
+  # Determine whether to use the "cached" dataset, which requires externally
+  # provided parsing functions, or the default one with its internal parsing logic.
+  make_cached_tfrecord_iterator = (
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
+  )
 
   feature_description = {
       "moments": tf.io.FixedLenFeature([], tf.string),
       "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }
 
+  used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
+
   def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
+    return tf.io.parse_single_example(example, used_feature_description)
 
   def prepare_sample(features):
     moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
     clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
     return {"pixel_values": moments, "input_ids": clip_embeddings}
 
-  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
-  train_ds = (
-      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
-      .shuffle(global_batch_size * 10)
+  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
+  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+
+  # --- PADDING LOGIC FOR EVALUATION ---
+  if not is_training:
+    num_eval_samples = 0
+    for _ in ds:
+      num_eval_samples += 1
+
+    remainder = num_eval_samples % global_batch_size
+    if remainder != 0:
+      num_to_pad = global_batch_size - remainder
+      # Create a dataset of padding samples from the beginning
+      padding_ds = ds.take(num_to_pad)
+      # Add the padding samples to the end
+      ds = ds.concatenate(padding_ds)
+      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
+
+  used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
+  ds = (
+      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
+  )
+  if is_training:
+    ds = (
+      ds.shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
-  )
+    )
+  # For Evaluation
+  else:
+    ds = (
+      ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
+      .prefetch(AUTOTUNE)
+    )
 
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
-  return train_iter
+  iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
+  return iter
+
+def make_tfrecord_iterator(
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
+):
+  """Iterator for TFRecord format. For Laion dataset,
+  check out preparation script
+  maxdiffusion/pedagogical_examples/to_tfrecords.py
+  """
+  # Currently only support evaluation on tfrecord. To avoid influencing previous reference, judge whether is training dataset.
+  # TODO: refactor to support evaluation on all dataset format.
+  dataset_path = config.train_data_dir if is_training else config.eval_data_dir
+  return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
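
The evaluation branch above pads the TFRecord dataset by re-appending records from its head until the sample count is a multiple of the global batch size, at the cost of one extra pass to count records. A self-contained sketch of the same idea on a toy in-memory dataset; the helper name and toy data are illustrative, not part of the commit:

    import tensorflow as tf

    def pad_to_batch_multiple(ds: tf.data.Dataset, global_batch_size: int) -> tf.data.Dataset:
      # Count with a full pass, as the diff does for TFRecords.
      num_samples = sum(1 for _ in ds)
      remainder = num_samples % global_batch_size
      if remainder != 0:
        num_to_pad = global_batch_size - remainder
        # Re-append the first num_to_pad samples to the end.
        ds = ds.concatenate(ds.take(num_to_pad))
      return ds

    # 10 samples padded up to a multiple of 4 -> 12 samples.
    padded = pad_to_batch_multiple(tf.data.Dataset.range(10), global_batch_size=4)
    assert sum(1 for _ in padded) == 12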

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@ def make_data_iterator(
     image_transforms_fn=None,
     feature_description=None,
     prepare_sample_fn=None,
+    is_training=True,
 ):
   """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
 
@@ -106,6 +107,7 @@ def make_data_iterator(
         global_batch_size,
         feature_description,
         prepare_sample_fn,
+        is_training
     )
   else:
     assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
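
make_data_iterator gains an is_training flag that defaults to True, so existing training callers are unchanged while the tfrecord branch can build an evaluation iterator. A stripped-down, runnable sketch of that dispatch pattern; the real function takes many more arguments, and everything below is illustrative only:

    def make_data_iterator_sketch(dataset_type: str, is_training: bool = True) -> str:
      # Only the tfrecord branch honors is_training; it selects
      # config.train_data_dir vs config.eval_data_dir downstream.
      if dataset_type == "tfrecord":
        return f"tfrecord iterator ({'train' if is_training else 'eval'})"
      raise ValueError(f"Unknown dataset_type {dataset_type}")

    assert make_data_iterator_sketch("tfrecord") == "tfrecord iterator (train)"
    assert make_data_iterator_sketch("tfrecord", is_training=False) == "tfrecord iterator (eval)"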

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 97 additions & 6 deletions
@@ -108,7 +108,7 @@ def get_data_shardings(self, mesh):
     data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding
 
-  def load_dataset(self, mesh):
+  def load_dataset(self, mesh, is_training=True):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
     # Image pre-training - txt2img 256px
     # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
@@ -141,6 +141,7 @@ def prepare_sample(features):
         config.global_batch_size_to_load,
         feature_description=feature_description,
         prepare_sample_fn=prepare_sample,
+        is_training=is_training,
     )
     return data_iterator
 
@@ -155,20 +156,20 @@ def start_training(self):
     del pipeline.vae_cache
 
     mesh = pipeline.mesh
-    data_iterator = self.load_dataset(mesh)
+    train_data_iterator = self.load_dataset(mesh, is_training=True)
 
     # Load FlowMatch scheduler
     scheduler, scheduler_state = self.create_scheduler()
     pipeline.scheduler = scheduler
     pipeline.scheduler_state = scheduler_state
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     # Returns pipeline with trained transformer state
-    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
+    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
     print_ssim(pretrained_video_path, posttrained_video_path)
 
-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
 
@@ -203,6 +204,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
         out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
+    p_eval_step = jax.jit(
+        functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(None, None),
+    )
+
     rng = jax.random.key(self.config.seed)
     start_step = 0
     last_step_completion = datetime.datetime.now()
@@ -219,13 +226,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     per_device_tflops = self.calculate_tflops(pipeline)
 
     scheduler_state = pipeline.scheduler_state
-    example_batch = load_next_batch(data_iterator, None, self.config)
+    example_batch = load_next_batch(train_data_iterator, None, self.config)
     with ThreadPoolExecutor(max_workers=1) as executor:
       for step in np.arange(start_step, self.config.max_train_steps):
         if self.config.enable_profiler and step == first_profiling_step:
          max_utils.activate_profiler(self.config)
        start_step_time = datetime.datetime.now()
-        next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
+        next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
        with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
            self.config.logical_axis_rules
        ):
@@ -241,6 +248,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
          )
          if self.config.write_metrics:
            train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
+
+          if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
+            # Re-create the iterator each time you start evaluation to reset it
+            # This assumes your data loading logic can be called to get a fresh iterator.
+            eval_data_iterator = self.load_dataset(mesh, is_training=False)
+            eval_rng = jax.random.key(self.config.seed + step)
+            eval_metrics = []
+            # Loop indefinitely until the iterator is exhausted
+            while True:
+              try:
+                with mesh:
+                  eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+                  metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
+                  eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
+              except StopIteration:
+                # This block is executed when the iterator has no more data
+                break
+            # Check if any evaluation was actually performed
+            if eval_metrics:
+              eval_loss = jnp.mean(jnp.array(eval_metrics))
+              max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
+              if writer:
+                writer.add_scalar("learning/eval_loss", eval_loss, step)
+            else:
+              max_logging.log(f"Step {step}, evaluation dataset was empty.")
        example_batch = next_batch_future.result()
 
    _metrics_queue.put(None)
@@ -296,3 +328,62 @@ def loss_fn(params):
  new_state = state.apply_gradients(grads=grads)
  metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
  return new_state, scheduler_state, metrics, new_rng
+
+
+def eval_step(state, data, rng, scheduler_state, scheduler, config):
+  """
+  Computes the evaluation loss for a single batch without updating model weights.
+  """
+  _, new_rng, timestep_rng = jax.random.split(rng, num=3)
+
+  # This ensures the batch size is consistent, though it might be redundant
+  # if the evaluation dataloader is already configured correctly.
+  for k, v in data.items():
+    data[k] = v[: config.global_batch_size_to_train_on, :]
+
+  # The loss function logic is identical to training. We are evaluating the model's
+  # ability to perform its core training objective (e.g., denoising).
+  def loss_fn(params):
+    # Reconstruct the model from its definition and parameters
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
+
+    # Prepare inputs
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    bsz = latents.shape[0]
+
+    # Sample random timesteps and noise, just as in a training step
+    timesteps = jax.random.randint(
+        timestep_rng,
+        (bsz,),
+        0,
+        scheduler.config.num_train_timesteps,
+    )
+    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
+
+    # Get the model's prediction
+    model_pred = model(
+        hidden_states=noisy_latents,
+        timestep=timesteps,
+        encoder_hidden_states=encoder_hidden_states,
+    )
+
+    # Calculate the loss against the target
+    training_target = scheduler.training_target(latents, noise, timesteps)
+    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
+    loss = (training_target - model_pred) ** 2
+    loss = loss * training_weight
+    loss = jnp.mean(loss)
+
+    return loss
+
+  # --- Key Difference from train_step ---
+  # Directly compute the loss without calculating gradients.
+  # The model's state.params are used but not updated.
+  loss = loss_fn(state.params)
+
+  # Structure the metrics for logging and aggregation
+  metrics = {"scalar": {"learning/eval_loss": loss}}
+
+  # Return the computed metrics and the new RNG key for the next eval step
+  return metrics, new_rng
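
The new evaluation block drains a freshly created iterator until StopIteration and averages the per-batch eval_loss values before logging. A minimal sketch of that exhaust-and-average pattern, with a stub standing in for p_eval_step (names ending in _stub are illustrative, not part of the commit):

    import numpy as np

    def eval_step_stub(batch):
      # Stand-in for p_eval_step: returns a per-batch scalar loss.
      return float(np.mean(batch))

    def run_eval_stub(eval_iterator):
      eval_metrics = []
      while True:
        try:
          batch = next(eval_iterator)  # raises StopIteration when exhausted
          eval_metrics.append(eval_step_stub(batch))
        except StopIteration:
          break
      # Average over batches, or None if the dataset was empty.
      return float(np.mean(eval_metrics)) if eval_metrics else None

    # Toy usage: three "batches" -> mean of per-batch means (1.5, 3.0, 4.5) == 3.0.
    print(run_eval_stub(iter([np.array([1.0, 2.0]), np.array([3.0]), np.array([4.0, 5.0])])))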
