diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 40a76c6cf..ea3abc773 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -291,3 +291,6 @@ use_qwix_quantization: False # Whether to use qwix for quantization. If set to T
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 quantization_calibration_method: "absmax"
+# Evaluate the model every eval_every steps. -1 means don't eval.
+eval_every: -1
+eval_data_dir: ""
diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index 562d5c718..885d59ef6 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -19,7 +19,7 @@ import tensorflow.experimental.numpy as tnp
 from datasets import load_dataset, load_from_disk
 import jax
 
-from maxdiffusion import multihost_dataloading
+from maxdiffusion import multihost_dataloading, max_logging
 
 AUTOTUNE = tf.data.AUTOTUNE
 
@@ -78,92 +78,91 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
-
-def make_cached_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
-):
-  """
-  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
-  latents, input_ids, prompt_embeds, and text_embeds.
-  """
-
-  def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
-
-  # This pipeline reads the sharded files and applies the parsing and preparation.
-  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
-
-  train_ds = (
-      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
-      .shuffle(global_batch_size * 10)
-      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .repeat(-1)
-      .prefetch(AUTOTUNE)
-  )
-
-  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
-  return train_iter
-
-
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
-def make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
+def _make_tfrecord_iterator(
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
 ):
-  """Iterator for TFRecord format. For Laion dataset,
-  check out preparation script
-  maxdiffusion/pedagogical_examples/to_tfrecords.py
-  """
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
+  # If is_training is True, loads the training dataset; if False, loads the evaluation dataset.
   # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
-  if (
-      config.cache_latents_text_encoder_outputs
-      and is_dataset_dir_valid
-      and "load_tfrecord_cached" in config.get_keys()
-      and config.load_tfrecord_cached
-  ):
-    return make_cached_tfrecord_iterator(
-        config,
-        dataloading_host_index,
-        dataloading_host_count,
-        mesh,
-        global_batch_size,
-        feature_description,
-        prepare_sample_fn,
-    )
+  # Determine whether to use the "cached" dataset, which requires externally
+  # provided parsing functions, or the default one with its internal parsing logic.
+  use_cached_tfrecord = (
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
+  )
 
   feature_description = {
       "moments": tf.io.FixedLenFeature([], tf.string),
       "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }
 
+  used_feature_description = feature_description_fn if use_cached_tfrecord else feature_description
+
   def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
+    return tf.io.parse_single_example(example, used_feature_description)
 
   def prepare_sample(features):
     moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
     clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
     return {"pixel_values": moments, "input_ids": clip_embeddings}
 
-  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
-  train_ds = (
-      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
-      .shuffle(global_batch_size * 10)
+  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
+  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+
+  # --- PADDING LOGIC FOR EVALUATION ---
+  if not is_training:
+    num_eval_samples = 0
+    for _ in ds:
+      num_eval_samples += 1
+
+    remainder = num_eval_samples % global_batch_size
+    if remainder != 0:
+      num_to_pad = global_batch_size - remainder
+      # Create a dataset of padding samples from the beginning
+      padding_ds = ds.take(num_to_pad)
+      # Add the padding samples to the end
+      ds = ds.concatenate(padding_ds)
+      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
+
+  used_prepare_sample = prepare_sample_fn if use_cached_tfrecord else prepare_sample
+  ds = (
+      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
+  )
+  if is_training:
+    ds = (
+        ds.shuffle(global_batch_size * 10)
       .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
       .repeat(-1)
       .prefetch(AUTOTUNE)
-  )
+    )
+  # For Evaluation
+  else:
+    ds = (
+        ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
+        .prefetch(AUTOTUNE)
+    )
 
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
-  return train_iter
+  data_iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
+  return data_iter
+
+
+def make_tfrecord_iterator(
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
+):
+  """Iterator for TFRecord format. For Laion dataset,
+  check out preparation script
+  maxdiffusion/pedagogical_examples/to_tfrecords.py
+  """
+  # Currently only supports evaluation on tfrecord. To avoid changing existing call sites, check whether this is the training dataset.
+  # TODO: refactor to support evaluation on all dataset formats.
+  dataset_path = config.train_data_dir if is_training else config.eval_data_dir
+  return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
index 0c1c68602..e7014bbc3 100644
--- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
+++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
@@ -52,6 +52,7 @@ def make_data_iterator(
     image_transforms_fn=None,
     feature_description=None,
     prepare_sample_fn=None,
+    is_training=True,
 ):
   """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
 
@@ -106,6 +107,7 @@ def make_data_iterator(
         global_batch_size,
         feature_description,
         prepare_sample_fn,
+        is_training,
     )
   else:
     assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 3b0b520bf..2818e4102 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -101,7 +101,7 @@ def get_data_shardings(self, mesh):
     data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding
 
-  def load_dataset(self, mesh):
+  def load_dataset(self, mesh, is_training=True):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
     # Image pre-training - txt2img 256px
     # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
@@ -134,6 +134,7 @@ def prepare_sample(features):
         config.global_batch_size_to_load,
         feature_description=feature_description,
         prepare_sample_fn=prepare_sample,
+        is_training=is_training,
     )
     return data_iterator
 
@@ -145,7 +146,7 @@ def start_training(self):
     # Generate a sample before training to compare against generated sample after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
     mesh = pipeline.mesh
-    data_iterator = self.load_dataset(mesh)
+    train_data_iterator = self.load_dataset(mesh, is_training=True)
 
     # Load FlowMatch scheduler
     scheduler, scheduler_state = self.create_scheduler()
@@ -153,12 +154,12 @@ def start_training(self):
     pipeline.scheduler_state = scheduler_state
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     # Returns pipeline with trained transformer state
-    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
+    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
     print_ssim(pretrained_video_path, posttrained_video_path)
 
-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
@@ -193,6 +194,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
         out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
+    p_eval_step = jax.jit(
+        functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(None, None),
+    )
+
     rng = jax.random.key(self.config.seed)
     start_step = 0
     last_step_completion = datetime.datetime.now()
@@ -209,13 +216,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     per_device_tflops = self.calculate_tflops(pipeline)
     scheduler_state = pipeline.scheduler_state
-    example_batch = load_next_batch(data_iterator, None, self.config)
+    example_batch = load_next_batch(train_data_iterator, None, self.config)
     with ThreadPoolExecutor(max_workers=1) as executor:
       for step in np.arange(start_step, self.config.max_train_steps):
         if self.config.enable_profiler and step == first_profiling_step:
           max_utils.activate_profiler(self.config)
         start_step_time = datetime.datetime.now()
-        next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
+        next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
         with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
             self.config.logical_axis_rules
         ):
@@ -231,6 +238,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
         )
         if self.config.write_metrics:
           train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
+
+        if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
+          # Re-create the eval iterator at the start of each evaluation so it is reset.
+          # This assumes the data loading logic can be called again to get a fresh iterator.
+          eval_data_iterator = self.load_dataset(mesh, is_training=False)
+          eval_rng = jax.random.key(self.config.seed + step)
+          eval_metrics = []
+          # Loop until the eval iterator is exhausted
+          while True:
+            try:
+              with mesh:
+                eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+                metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
+                eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
+            except StopIteration:
+              # The iterator has no more data; this evaluation pass is done.
+              break
+          # Check if any evaluation was actually performed
+          if eval_metrics:
+            eval_loss = jnp.mean(jnp.array(eval_metrics))
+            max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
+            if writer:
+              writer.add_scalar("learning/eval_loss", eval_loss, step)
+          else:
+            max_logging.log(f"Step {step}, evaluation dataset was empty.")
         example_batch = next_batch_future.result()
 
   _metrics_queue.put(None)
@@ -286,3 +318,62 @@ def loss_fn(params):
   new_state = state.apply_gradients(grads=grads)
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
   return new_state, scheduler_state, metrics, new_rng
+
+
+def eval_step(state, data, rng, scheduler_state, scheduler, config):
+  """
+  Computes the evaluation loss for a single batch without updating model weights.
+  """
+  _, new_rng, timestep_rng = jax.random.split(rng, num=3)
+
+  # This ensures the batch size is consistent, though it might be redundant
+  # if the evaluation dataloader is already configured correctly.
+  for k, v in data.items():
+    data[k] = v[: config.global_batch_size_to_train_on, :]
+
+  # The loss function logic is identical to training. We are evaluating the model's
+  # ability to perform its core training objective (e.g., denoising).
+  def loss_fn(params):
+    # Reconstruct the model from its definition and parameters
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
+
+    # Prepare inputs
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    bsz = latents.shape[0]
+
+    # Sample random timesteps and noise, just as in a training step
+    timesteps = jax.random.randint(
+        timestep_rng,
+        (bsz,),
+        0,
+        scheduler.config.num_train_timesteps,
+    )
+    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
+
+    # Get the model's prediction
+    model_pred = model(
+        hidden_states=noisy_latents,
+        timestep=timesteps,
+        encoder_hidden_states=encoder_hidden_states,
+    )
+
+    # Calculate the loss against the target
+    training_target = scheduler.training_target(latents, noise, timesteps)
+    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
+    loss = (training_target - model_pred) ** 2
+    loss = loss * training_weight
+    loss = jnp.mean(loss)
+
+    return loss
+
+  # --- Key Difference from train_step ---
+  # Directly compute the loss without calculating gradients.
+  # The model's state.params are used but not updated.
+  loss = loss_fn(state.params)
+
+  # Structure the metrics for logging and aggregation
+  metrics = {"scalar": {"learning/eval_loss": loss}}
+
+  # Return the computed metrics and the new RNG key for the next eval step
+  return metrics, new_rng