Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,6 @@ use_qwix_quantization: False # Whether to use qwix for quantization. If set to T
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
quantization_calibration_method: "absmax"

# Evaluate the model every eval_every steps. -1 means don't eval.
eval_every: -1
eval_data_dir: ""
127 changes: 63 additions & 64 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tensorflow.experimental.numpy as tnp
from datasets import load_dataset, load_from_disk
import jax
from maxdiffusion import multihost_dataloading
from maxdiffusion import multihost_dataloading, max_logging

AUTOTUNE = tf.data.AUTOTUNE

Expand Down Expand Up @@ -78,92 +78,91 @@ def make_tf_iterator(
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter


def make_cached_tfrecord_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
):
  """
  Iterator over TFRecords that already contain the full set of pre-computed
  latents and embeddings: latents, input_ids, prompt_embeds, and text_embeds.
  """

  def _decode(serialized_example):
    # Parsing is delegated entirely to the caller-supplied feature spec.
    return tf.io.parse_single_example(serialized_example, feature_description)

  # Every file in the training directory is treated as a TFRecord shard.
  shard_files = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

  # Build the per-host pipeline one stage at a time: read, shard, parse,
  # prepare, shuffle, batch, repeat forever, prefetch.
  dataset = tf.data.TFRecordDataset(shard_files, num_parallel_reads=AUTOTUNE)
  dataset = dataset.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
  dataset = dataset.map(_decode, num_parallel_calls=AUTOTUNE)
  dataset = dataset.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
  dataset = dataset.shuffle(global_batch_size * 10)
  dataset = dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
  dataset = dataset.repeat(-1)
  dataset = dataset.prefetch(AUTOTUNE)

  # Wrap the tf.data.Dataset for use in the multi-host JAX environment.
  return multihost_dataloading.MultiHostDataLoadIterator(dataset, mesh)


# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def _make_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
    feature_description_fn,
    prepare_sample_fn,
    dataset_path,
    is_training: bool,
):
  """Builds a multi-host iterator over a directory of TFRecord shards.

  Two record layouts are supported:
    * the default (Laion) layout with "moments" + "clip_embeddings", parsed by
      the internal feature description / prepare_sample below. Preparation
      script: maxdiffusion/pedagogical_examples/to_tfrecords.py.
    * a cached layout with pre-computed latents/embeddings, parsed by the
      externally supplied feature_description_fn / prepare_sample_fn. See
      pedagogical_examples/dataset_tf_cache_to_tfrecord.py for conversion.

  Args:
    config: run config; reads dataset_save_location,
      cache_latents_text_encoder_outputs and the optional load_tfrecord_cached key.
    dataloading_host_index: shard index assigned to this host.
    dataloading_host_count: total number of data-loading hosts.
    mesh: JAX device mesh for multi-host loading.
    global_batch_size: total batch size summed across hosts.
    feature_description_fn: tf.io feature spec used when the cached layout is active.
    prepare_sample_fn: sample-preparation map fn used when the cached layout is active.
    dataset_path: directory (local or gs://) containing the TFRecord shards.
    is_training: True loads a shuffled, infinitely repeated training pipeline;
      False loads a single-pass evaluation pipeline padded to full batches.

  Returns:
    A multihost_dataloading.MultiHostDataLoadIterator over the dataset.
  """
  # Checks that the dataset path is valid. For gcs paths the existence of the
  # directory is not checked.
  is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

  # Determine whether to use the "cached" dataset layout, which requires
  # externally provided parsing functions, or the default one with its internal
  # parsing logic. (Renamed from `make_cached_tfrecord_iterator`: a boolean
  # flag must not shadow a function-style name.)
  use_cached_format = (
      config.cache_latents_text_encoder_outputs
      and is_dataset_dir_valid
      and "load_tfrecord_cached" in config.get_keys()
      and config.load_tfrecord_cached
  )

  # Default (Laion) record layout.
  feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
  }

  used_feature_description = feature_description_fn if use_cached_format else feature_description

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, used_feature_description)

  def prepare_sample(features):
    # Default preparation: deserialize the stored VAE moments and CLIP embeddings.
    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
    return {"pixel_values": moments, "input_ids": clip_embeddings}

  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)

  # --- PADDING LOGIC FOR EVALUATION ---
  # Pad the eval set (repeating samples from its head) up to a multiple of
  # global_batch_size so every host sees the same number of full batches.
  if not is_training:
    # NOTE: counting requires one full pass over the records before the real
    # pipeline runs; acceptable for small eval sets — consider caching the
    # count for large ones.
    num_eval_samples = sum(1 for _ in ds)
    remainder = num_eval_samples % global_batch_size
    if remainder != 0:
      num_to_pad = global_batch_size - remainder
      # Take the padding samples from the beginning and append them at the end.
      ds = ds.concatenate(ds.take(num_to_pad))
      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")

  used_prepare_sample = prepare_sample_fn if use_cached_format else prepare_sample
  ds = (
      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
  )
  if is_training:
    ds = (
        ds.shuffle(global_batch_size * 10)
        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
        .repeat(-1)
        .prefetch(AUTOTUNE)
    )
  else:
    # Evaluation: single pass, keep the (padded) final batch.
    ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE)

  # `iter` is a builtin — use a distinct name for the multi-host iterator.
  data_iterator = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
  return data_iterator

def make_tfrecord_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
):
  """Public entry point for TFRecord iterators (train and eval).

  For the Laion dataset, see the preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py.
  """
  # Evaluation is currently only supported for the tfrecord format; the
  # is_training flag picks the data directory so existing training callers
  # are unaffected.
  # TODO: refactor to support evaluation on all dataset formats.
  if is_training:
    dataset_path = config.train_data_dir
  else:
    dataset_path = config.eval_data_dir
  return _make_tfrecord_iterator(
      config,
      dataloading_host_index,
      dataloading_host_count,
      mesh,
      global_batch_size,
      feature_description,
      prepare_sample_fn,
      dataset_path,
      is_training,
  )
2 changes: 2 additions & 0 deletions src/maxdiffusion/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def make_data_iterator(
image_transforms_fn=None,
feature_description=None,
prepare_sample_fn=None,
is_training=True,
):
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""

Expand Down Expand Up @@ -106,6 +107,7 @@ def make_data_iterator(
global_batch_size,
feature_description,
prepare_sample_fn,
is_training
)
else:
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
Expand Down
103 changes: 97 additions & 6 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_data_shardings(self, mesh):
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
return data_sharding

def load_dataset(self, mesh):
def load_dataset(self, mesh, is_training=True):
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
# Image pre-training - txt2img 256px
# Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
Expand Down Expand Up @@ -134,6 +134,7 @@ def prepare_sample(features):
config.global_batch_size_to_load,
feature_description=feature_description,
prepare_sample_fn=prepare_sample,
is_training=is_training,
)
return data_iterator

Expand All @@ -145,20 +146,20 @@ def start_training(self):
# Generate a sample before training to compare against generated sample after training.
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
mesh = pipeline.mesh
data_iterator = self.load_dataset(mesh)
train_data_iterator = self.load_dataset(mesh, is_training=True)

# Load FlowMatch scheduler
scheduler, scheduler_state = self.create_scheduler()
pipeline.scheduler = scheduler
pipeline.scheduler_state = scheduler_state
optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
# Returns pipeline with trained transformer state
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
print_ssim(pretrained_video_path, posttrained_video_path)

def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
mesh = pipeline.mesh
graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

Expand Down Expand Up @@ -193,6 +194,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
out_shardings=(state_shardings, None, None, None),
donate_argnums=(0,),
)
p_eval_step = jax.jit(
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
in_shardings=(state_shardings, data_shardings, None, None),
out_shardings=(None, None),
)

rng = jax.random.key(self.config.seed)
start_step = 0
last_step_completion = datetime.datetime.now()
Expand All @@ -209,13 +216,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
per_device_tflops = self.calculate_tflops(pipeline)

scheduler_state = pipeline.scheduler_state
example_batch = load_next_batch(data_iterator, None, self.config)
example_batch = load_next_batch(train_data_iterator, None, self.config)
with ThreadPoolExecutor(max_workers=1) as executor:
for step in np.arange(start_step, self.config.max_train_steps):
if self.config.enable_profiler and step == first_profiling_step:
max_utils.activate_profiler(self.config)
start_step_time = datetime.datetime.now()
next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
self.config.logical_axis_rules
):
Expand All @@ -231,6 +238,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
)
if self.config.write_metrics:
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)

if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
# Re-create the iterator each time you start evaluation to reset it
# This assumes your data loading logic can be called to get a fresh iterator.
eval_data_iterator = self.load_dataset(mesh, is_training=False)
eval_rng = jax.random.key(self.config.seed + step)
eval_metrics = []
# Loop indefinitely until the iterator is exhausted
while True:
try:
with mesh:
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
except StopIteration:
# This block is executed when the iterator has no more data
break
# Check if any evaluation was actually performed
if eval_metrics:
eval_loss = jnp.mean(jnp.array(eval_metrics))
max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
if writer:
writer.add_scalar("learning/eval_loss", eval_loss, step)
else:
max_logging.log(f"Step {step}, evaluation dataset was empty.")
example_batch = next_batch_future.result()

_metrics_queue.put(None)
Expand Down Expand Up @@ -286,3 +318,62 @@ def loss_fn(params):
new_state = state.apply_gradients(grads=grads)
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
return new_state, scheduler_state, metrics, new_rng

def eval_step(state, data, rng, scheduler_state, scheduler, config):
  """
  Computes the evaluation loss for a single batch without updating model weights.

  The objective mirrors the training step (denoising loss), so train and eval
  losses are directly comparable.

  Returns:
    (metrics, new_rng): metrics dict carrying "learning/eval_loss", plus the
    RNG key to thread into the next eval step.
  """
  _, noise_rng, timestep_rng = jax.random.split(rng, num=3)

  # Trim the batch to the per-step training batch size so shapes match the
  # compiled step. This may be redundant if the evaluation dataloader is
  # already configured correctly.
  for key in data:
    data[key] = data[key][: config.global_batch_size_to_train_on, :]

  # Reconstruct the model from its split representation; parameters are used
  # read-only here — no gradients are computed.
  model = nnx.merge(state.graphdef, state.params, state.rest_of_state)

  latents = data["latents"].astype(config.weights_dtype)
  encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
  batch_size = latents.shape[0]

  # Sample per-example timesteps and noise exactly as a training step would.
  timesteps = jax.random.randint(
      timestep_rng,
      (batch_size,),
      0,
      scheduler.config.num_train_timesteps,
  )
  noise = jax.random.normal(key=noise_rng, shape=latents.shape, dtype=latents.dtype)
  noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

  # Forward pass only.
  model_pred = model(
      hidden_states=noisy_latents,
      timestep=timesteps,
      encoder_hidden_states=encoder_hidden_states,
  )

  # Weighted MSE against the scheduler's target — identical to the train loss.
  training_target = scheduler.training_target(latents, noise, timesteps)
  training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
  loss = jnp.mean(((training_target - model_pred) ** 2) * training_weight)

  # Structure the metrics for logging and aggregation.
  metrics = {"scalar": {"learning/eval_loss": loss}}

  # Return the computed metrics and the new RNG key for the next eval step.
  return metrics, noise_rng
Loading