Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,6 @@ use_qwix_quantization: False # Whether to use qwix for quantization. If set to T
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
quantization_calibration_method: "absmax"

# Evaluate the model every eval_every steps. -1 means don't eval.
eval_every: -1
eval_data_dir: ""
125 changes: 62 additions & 63 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,92 +78,91 @@ def make_tf_iterator(
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter


def make_cached_tfrecord_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
):
  """Build a multi-host iterator over TFRecords that hold the 4 pre-computed
  latents and embeddings: latents, input_ids, prompt_embeds, and text_embeds.

  Args:
    config: Run config; `config.train_data_dir` points at the TFRecord shards.
    dataloading_host_index: Index of this host among the dataloading hosts.
    dataloading_host_count: Total number of dataloading hosts.
    mesh: JAX device mesh used by the multi-host loader.
    global_batch_size: Global batch size; each host loads its per-host share.
    feature_description: Schema used to parse each serialized example.
    prepare_sample_fn: Maps a parsed example to the model's input features.

  Returns:
    A MultiHostDataLoadIterator yielding per-host batches indefinitely.
  """

  def _decode(serialized_example):
    # Deserialize a single record using the caller-provided schema.
    return tf.io.parse_single_example(serialized_example, feature_description)

  # Collect every shard file under the training data directory.
  record_files = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

  dataset = tf.data.TFRecordDataset(record_files, num_parallel_reads=AUTOTUNE)
  # Each host reads only its own slice of the records.
  dataset = dataset.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
  dataset = dataset.map(_decode, num_parallel_calls=AUTOTUNE)
  dataset = dataset.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
  dataset = dataset.shuffle(global_batch_size * 10)
  dataset = dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
  dataset = dataset.repeat(-1)
  dataset = dataset.prefetch(AUTOTUNE)

  # Wrap the tf.data.Dataset for use in the multi-host JAX environment.
  return multihost_dataloading.MultiHostDataLoadIterator(dataset, mesh)


# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
def _make_tfrecord_iterator(
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
):
"""Iterator for TFRecord format. For Laion dataset,
check out preparation script
maxdiffusion/pedagogical_examples/to_tfrecords.py
"""
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
# The dataset cache in the GitHub runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.
# if is_training is True, loads the training dataset. If False, loads the evaluation dataset.

# Checks that the dataset path is valid. In the case of GCS, the existence of the dir is not checked.
is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

if (
config.cache_latents_text_encoder_outputs
and is_dataset_dir_valid
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
):
return make_cached_tfrecord_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
feature_description,
prepare_sample_fn,
)
# Determine whether to use the "cached" dataset, which requires externally
# provided parsing functions, or the default one with its internal parsing logic.
make_cached_tfrecord_iterator = (
config.cache_latents_text_encoder_outputs
and is_dataset_dir_valid
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
)

feature_description = {
"moments": tf.io.FixedLenFeature([], tf.string),
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
}

used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description

def _parse_tfrecord_fn(example):
return tf.io.parse_single_example(example, feature_description)
return tf.io.parse_single_example(example, used_feature_description)

def prepare_sample(features):
moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
return {"pixel_values": moments, "input_ids": clip_embeddings}

filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
train_ds = (
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
.shuffle(global_batch_size * 10)
filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)

# --- PADDING LOGIC FOR EVALUATION ---
if not is_training:
num_eval_samples = 0
for _ in ds:
num_eval_samples += 1

remainder = num_eval_samples % global_batch_size
if remainder != 0:
num_to_pad = global_batch_size - remainder
# Create a dataset of padding samples from the beginning
padding_ds = ds.take(num_to_pad)
# Add the padding samples to the end
ds = ds.concatenate(padding_ds)
print(f"Padded evaluation dataset with {num_to_pad} samples.")
Comment thread
susanbao marked this conversation as resolved.
Outdated

used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
ds = (
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
)
if is_training:
ds = (
ds.shuffle(global_batch_size * 10)
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
.repeat(-1)
.prefetch(AUTOTUNE)
)
)
# For Evaluation
else:
ds = (
ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
.prefetch(AUTOTUNE)
)

train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter
iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
return iter

def make_tfrecord_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
):
  """Iterator for TFRecord format. For Laion dataset,
  check out preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py
  """
  # Evaluation is currently supported only for the tfrecord format; the
  # is_training flag selects the dataset directory without changing behavior
  # for existing training callers.
  # TODO: refactor to support evaluation on all dataset formats.
  if is_training:
    dataset_path = config.train_data_dir
  else:
    dataset_path = config.eval_data_dir
  return _make_tfrecord_iterator(
      config,
      dataloading_host_index,
      dataloading_host_count,
      mesh,
      global_batch_size,
      feature_description,
      prepare_sample_fn,
      dataset_path,
      is_training,
  )
2 changes: 2 additions & 0 deletions src/maxdiffusion/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def make_data_iterator(
image_transforms_fn=None,
feature_description=None,
prepare_sample_fn=None,
is_training=True,
):
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""

Expand Down Expand Up @@ -106,6 +107,7 @@ def make_data_iterator(
global_batch_size,
feature_description,
prepare_sample_fn,
is_training
)
else:
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
Expand Down
103 changes: 97 additions & 6 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_data_shardings(self, mesh):
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
return data_sharding

def load_dataset(self, mesh):
def load_dataset(self, mesh, is_training=True):
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
# Image pre-training - txt2img 256px
# Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
Expand Down Expand Up @@ -134,6 +134,7 @@ def prepare_sample(features):
config.global_batch_size_to_load,
feature_description=feature_description,
prepare_sample_fn=prepare_sample,
is_training=is_training,
)
return data_iterator

Expand All @@ -145,20 +146,20 @@ def start_training(self):
# Generate a sample before training to compare against generated sample after training.
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
mesh = pipeline.mesh
data_iterator = self.load_dataset(mesh)
train_data_iterator = self.load_dataset(mesh, is_training=True)

# Load FlowMatch scheduler
scheduler, scheduler_state = self.create_scheduler()
pipeline.scheduler = scheduler
pipeline.scheduler_state = scheduler_state
optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
# Returns pipeline with trained transformer state
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
print_ssim(pretrained_video_path, posttrained_video_path)

def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
mesh = pipeline.mesh
graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

Expand Down Expand Up @@ -193,6 +194,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
out_shardings=(state_shardings, None, None, None),
donate_argnums=(0,),
)
p_eval_step = jax.jit(
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
in_shardings=(state_shardings, data_shardings, None, None),
out_shardings=(None, None),
)

rng = jax.random.key(self.config.seed)
start_step = 0
last_step_completion = datetime.datetime.now()
Expand All @@ -209,13 +216,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
per_device_tflops = self.calculate_tflops(pipeline)

scheduler_state = pipeline.scheduler_state
example_batch = load_next_batch(data_iterator, None, self.config)
example_batch = load_next_batch(train_data_iterator, None, self.config)
with ThreadPoolExecutor(max_workers=1) as executor:
for step in np.arange(start_step, self.config.max_train_steps):
if self.config.enable_profiler and step == first_profiling_step:
max_utils.activate_profiler(self.config)
start_step_time = datetime.datetime.now()
next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
self.config.logical_axis_rules
):
Expand All @@ -231,6 +238,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
)
if self.config.write_metrics:
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)

if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
# Re-create the iterator each time you start evaluation to reset it
# This assumes your data loading logic can be called to get a fresh iterator.
eval_data_iterator = self.load_dataset(mesh, is_training=False)
eval_rng = jax.random.key(self.config.seed + step)
eval_metrics = []
# Loop indefinitely until the iterator is exhausted
while True:
try:
with mesh:
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
except StopIteration:
# This block is executed when the iterator has no more data
break
# Check if any evaluation was actually performed
if eval_metrics:
eval_loss = jnp.mean(jnp.array(eval_metrics))
max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
if writer:
writer.add_scalar("learning/eval_loss", eval_loss, step)
else:
max_logging.log(f"Step {step}, evaluation dataset was empty.")
example_batch = next_batch_future.result()

_metrics_queue.put(None)
Expand Down Expand Up @@ -286,3 +318,62 @@ def loss_fn(params):
new_state = state.apply_gradients(grads=grads)
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
return new_state, scheduler_state, metrics, new_rng

def eval_step(state, data, rng, scheduler_state, scheduler, config):
  """
  Computes the evaluation loss for a single batch without updating model weights.

  Args:
    state: Train state carrying graphdef, params, and rest_of_state for the
      transformer (consumed via nnx.merge).
    data: Batch dict with "latents" and "encoder_hidden_states" arrays.
    rng: PRNG key for this eval step.
    scheduler_state: State of the noise scheduler.
    scheduler: Scheduler providing add_noise / training_target / training_weight.
    config: Run config (weights_dtype, global_batch_size_to_train_on, ...).

  Returns:
    (metrics, new_rng): metrics dict containing "learning/eval_loss", and a
    fresh PRNG key to carry into the next eval step.
  """
  # Derive three independent keys: one to carry forward, one for noise, one
  # for timestep sampling. (Fix: previously one split output was discarded
  # and the carried key was reused as the noise key — PRNG key reuse.)
  new_rng, noise_rng, timestep_rng = jax.random.split(rng, num=3)

  # Trim to the per-step batch size without mutating the caller's dict.
  # NOTE(review): assumes every value is at least 2-D with batch first —
  # confirm against the eval dataloader output.
  data = {k: v[: config.global_batch_size_to_train_on, :] for k, v in data.items()}

  # The loss logic is identical to training: we evaluate the model's ability
  # to perform its core training objective (denoising). No gradients are
  # computed and no parameters are updated.
  def loss_fn(params):
    # Reconstruct the model from its definition and parameters.
    model = nnx.merge(state.graphdef, params, state.rest_of_state)

    # Prepare inputs.
    latents = data["latents"].astype(config.weights_dtype)
    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
    bsz = latents.shape[0]

    # Sample random timesteps and noise, just as in a training step.
    timesteps = jax.random.randint(
        timestep_rng,
        (bsz,),
        0,
        scheduler.config.num_train_timesteps,
    )
    noise = jax.random.normal(key=noise_rng, shape=latents.shape, dtype=latents.dtype)
    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

    # Get the model's prediction.
    model_pred = model(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=encoder_hidden_states,
    )

    # Weighted MSE against the scheduler's training target.
    training_target = scheduler.training_target(latents, noise, timesteps)
    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
    loss = (training_target - model_pred) ** 2
    loss = loss * training_weight
    loss = jnp.mean(loss)

    return loss

  # Directly compute the loss — state.params are read but never updated.
  loss = loss_fn(state.params)

  # Structure the metrics for logging and aggregation.
  metrics = {"scalar": {"learning/eval_loss": loss}}

  # Return the computed metrics and the new RNG key for the next eval step.
  return metrics, new_rng
Loading