diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1dd81b075..9ef1b72d6 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -141,14 +141,15 @@ ici_tensor_parallelism: 1 # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tf' +dataset_type: 'tfrecord' cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory # prepare image latents and text encoder outputs # Reduce memory consumption and reduce step time during training # transformed dataset is saved at dataset_save_location -dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +dataset_save_location: '' +load_tfrecord_cached: True train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' @@ -185,6 +186,10 @@ per_device_batch_size: 1 # If global_batch_size % jax.device_count is not 0, use FSDP sharding. global_batch_size: 0 +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 + warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. diff --git a/src/maxdiffusion/data_preprocessing/__init__.py b/src/maxdiffusion/data_preprocessing/__init__.py new file mode 100644 index 000000000..7e4185f36 --- /dev/null +++ b/src/maxdiffusion/data_preprocessing/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ diff --git a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py new file mode 100644 index 000000000..ae0b15f47 --- /dev/null +++ b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py @@ -0,0 +1,153 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +""" +Prepare tfrecords with latents and text embeddings preprocessed. +1. 
Download the dataset
+"""
+
+import os
+import functools
+from absl import app
+from typing import Sequence, Union, List
+from datasets import load_dataset
+import numpy as np
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh
+from maxdiffusion import pyconfig, max_utils
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+from maxdiffusion.video_processor import VideoProcessor
+
+import tensorflow as tf
+
+
+def image_feature(value):
+  """Returns a bytes_list Feature holding `value` encoded as JPEG bytes."""
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))
+
+
+def bytes_feature(value):
+  """Returns a bytes_list Feature from the `.numpy()` value of a tensor."""
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))
+
+
+def float_feature(value):
+  """Returns a float_list from a float / double."""
+  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
+
+
+def int64_feature(value):
+  """Returns an int64_list from a bool / enum / int / uint."""
+  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def float_feature_list(value):
+  """Returns a float_list from a sequence of floats / doubles."""
+  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def create_example(latent, hidden_states):  # Packs one latent / text-embedding pair into a serialized Example.
+  latent = tf.io.serialize_tensor(latent)
+  hidden_states = tf.io.serialize_tensor(hidden_states)
+  feature = {
+      "latents": bytes_feature(latent),
+      "encoder_hidden_states": bytes_feature(hidden_states),
+  }
+  example = tf.train.Example(features=tf.train.Features(feature=feature))
+  return example.SerializeToString()
+
+
+def text_encode(pipeline, prompt: Union[str, List[str]]):  # Returns the T5 prompt embeddings as a NumPy array.
+  encoder_hidden_states = pipeline._get_t5_prompt_embeds(prompt)
+  encoder_hidden_states = encoder_hidden_states.detach().numpy()
+  return encoder_hidden_states
+
+
+def vae_encode(video, rng, vae, vae_cache):  # Encodes `video` and samples a latent from `latent_dist` with `rng`.
+  latent = vae.encode(video, feat_cache=vae_cache)
+  latent = latent.latent_dist.sample(rng)
+  return latent
+
+
+def generate_dataset(config, pipeline):
+  """Encodes the dataset with the VAE and text encoder and writes the results as sharded TFRecord files."""
+  tfrecords_dir = config.tfrecords_dir
+  if not os.path.exists(tfrecords_dir):
+    os.makedirs(tfrecords_dir)
+
+  tf_rec_num = 0
+  no_records_per_shard = config.no_records_per_shard
+  global_record_count = 0
+  writer = tf.io.TFRecordWriter(
+      tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
+  )
+  shard_record_count = 0
+
+  # create mesh
+  devices_array = max_utils.create_device_mesh(config)
+  mesh = Mesh(devices_array, config.mesh_axes)
+  rng = jax.random.key(config.seed)
+
+  vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
+  video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
+
+  # jit vae fun.
+  p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
+
+  # Load dataset
+  ds = load_dataset(config.dataset_name, split="train")
+  ds = ds.shuffle(seed=config.seed)
+  ds = ds.select_columns([config.caption_column, config.image_column])
+  batch_size = 10
+  for i in range(0, len(ds), batch_size):
+    rng, new_rng = jax.random.split(rng)
+    text = ds[i : i + batch_size][config.caption_column]  # Only the configured columns were selected above.
+    videos = ds[i : i + batch_size][config.image_column]
+
+    videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
+    video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
+    with mesh:
+      latents = p_vae_encode(video=video, rng=new_rng)
+      latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
+    encoder_hidden_states = text_encode(pipeline, text)
+    for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
+      writer.write(create_example(latent, encoder_hidden_state))
+      shard_record_count += 1
+      global_record_count += 1
+      if shard_record_count >= no_records_per_shard:
+        writer.close()
+        tf_rec_num += 1
+        writer = tf.io.TFRecordWriter(
+            tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
+        )
+        shard_record_count = 0
+  writer.close()  # Close the last shard too; otherwise its records are not guaranteed to be flushed to disk.
+
+
+def run(config):
+  pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
+  # Don't need the transformer for preprocessing.
+  generate_dataset(config, pipeline)
+
+
+def main(argv: Sequence[str]) -> None:
+  pyconfig.initialize(argv)
+  run(pyconfig.config)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index 760d655cc..8486c79d5 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -21,9 +21,10 @@
 from maxdiffusion.utils import export_to_video
 
 
-def run(config):
+def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
-  pipeline = WanPipeline.from_pretrained(config)
+  if pipeline is None:
+    pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()
 
   # Skip layer guidance
@@ -59,7 +60,7 @@ def run(config):
 
   print("compile time: ", (time.perf_counter() - s0))
   for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
+    export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
   s0 = time.perf_counter()
   videos = pipeline(
       prompt=prompt,
diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index 6b588ed2d..454f65785 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -73,35 +73,18 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
+
 def make_cached_tfrecord_iterator(
-    config,
-    dataloading_host_index,
-    dataloading_host_count,
-    mesh,
-    global_batch_size,
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
 ):
   """
   New iterator for TFRecords that
contain the full 4 pre-computed latents and embeddings: latents, input_ids, prompt_embeds, and text_embeds. """ - feature_description = { - "pixel_values": tf.io.FixedLenFeature([], tf.string), - "input_ids": tf.io.FixedLenFeature([], tf.string), - "prompt_embeds": tf.io.FixedLenFeature([], tf.string), - "text_embeds": tf.io.FixedLenFeature([], tf.string), - } def _parse_tfrecord_fn(example): return tf.io.parse_single_example(example, feature_description) - def prepare_sample(features): - pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32) - input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32) - prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) - text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) - - return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds} - # This pipeline reads the sharded files and applies the parsing and preparation. filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) @@ -109,7 +92,7 @@ def prepare_sample(features): tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) .shard(num_shards=dataloading_host_count, index=dataloading_host_index) .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) - .map(prepare_sample, num_parallel_calls=AUTOTUNE) + .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE) .shuffle(global_batch_size * 10) .batch(global_batch_size // dataloading_host_count, drop_remainder=True) .repeat(-1) @@ -123,11 +106,7 @@ def prepare_sample(features): # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py def make_tfrecord_iterator( - config, - dataloading_host_index, - dataloading_host_count, - mesh, - global_batch_size, + config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn ): """Iterator for TFRecord format. 
For Laion dataset, check out preparation script @@ -136,12 +115,22 @@ def make_tfrecord_iterator( # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. - # Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. - if (config.cache_latents_text_encoder_outputs + # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. + if ( + config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location) - and 'load_tfrecord_cached'in config.get_keys() - and config.load_tfrecord_cached): - return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size) + and "load_tfrecord_cached" in config.get_keys() + and config.load_tfrecord_cached + ): + return make_cached_tfrecord_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description, + prepare_sample_fn, + ) feature_description = { "moments": tf.io.FixedLenFeature([], tf.string), diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 3a78ff09b..0c1c68602 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -50,8 +50,25 @@ def make_data_iterator( global_batch_size, tokenize_fn=None, image_transforms_fn=None, + feature_description=None, + prepare_sample_fn=None, ): """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" + + if config.dataset_type == "hf" or config.dataset_type == "tf": + if tokenize_fn is None or image_transforms_fn is None: + raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and 
image_transforms_fn")
+
+  if (
+      config.dataset_type == "tfrecord"
+      and config.cache_latents_text_encoder_outputs
+      # The error below asks for both arguments, so reject the call when either one is missing.
+      and (feature_description is None or prepare_sample_fn is None)
+  ):
+    raise ValueError(
+        f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True."
+    )
+
   if config.dataset_type == "hf":
     return _hf_data_processing.make_hf_streaming_iterator(
         config,
@@ -87,6 +104,8 @@ def make_data_iterator(
         dataloading_host_count,
         mesh,
         global_batch_size,
+        feature_description,
+        prepare_sample_fn,
     )
   else:
     assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index a3be8e138..db4b25fb2 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -41,6 +41,7 @@
 def basic_clean(text):
   if is_ftfy_available():
     import ftfy
+
     text = ftfy.fix_text(text)
   text = html.unescape(html.unescape(text))
   return text.strip()
@@ -221,7 +222,7 @@ def load_scheduler(cls, config):
     return scheduler, scheduler_state
 
   @classmethod
-  def from_pretrained(cls, config: HyperParameters, vae_only=False):
+  def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
     devices_array = max_utils.create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
     rng = jax.random.key(config.seed)
@@ -232,8 +233,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False):
     scheduler_state = None
     text_encoder = None
     if not vae_only:
-      with mesh:
-        transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+      if load_transformer:
+        with mesh:
+          transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
 
       text_encoder = cls.load_text_encoder(config=config)
tokenizer = cls.load_tokenizer(config=config)
@@ -397,7 +399,7 @@ def __call__(
         num_channels_latents=num_channel_latents,
     )
 
-    data_sharding = NamedSharding(self.devices_array, P())
+    data_sharding = NamedSharding(self.mesh, P())
     if len(prompt) % jax.device_count() == 0:
       data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
 
diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py
index edd249de1..26743e3d7 100644
--- a/src/maxdiffusion/schedulers/__init__.py
+++ b/src/maxdiffusion/schedulers/__init__.py
@@ -43,7 +43,7 @@
   _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
   _import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
   _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
-  _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
+  _import_structure["scheduling_flow_match_flax"] = ["FlaxFlowMatchScheduler"]
   _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
   _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
   _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
@@ -70,6 +70,7 @@
     from .scheduling_ddpm_flax import FlaxDDPMScheduler
     from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
     from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
+    from .scheduling_flow_match_flax import FlaxFlowMatchScheduler
     from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py
new file mode 100644
index 000000000..1f9c3a78e
--- /dev/null
+++ b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py
@@ -0,0 +1,293 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This is a JAX/Flax conversion of a PyTorch implementation.
+# It is intended to mirror the behaviour of the PyTorch scheduler it was converted from.
+
+from typing import Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+# Project-local configuration helpers and Flax scheduler base classes.
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+)
+
+
+@flax.struct.dataclass
+class FlowMatchSchedulerState:
+  """
+  Data class holding the state of the FlaxFlowMatchScheduler; it is updated functionally via `.replace()`.
+  """
+
+  sigmas: jnp.ndarray
+  timesteps: jnp.ndarray
+  linear_timesteps_weights: Optional[jnp.ndarray]
+  training: bool
+  num_inference_steps: int  # Store for training weight calculation
+
+  @classmethod
+  def create(cls):
+    return cls(
+        sigmas=None,
+        timesteps=None,
+        linear_timesteps_weights=None,
+        training=False,
+        num_inference_steps=0,
+    )
+
+
+@flax.struct.dataclass(frozen=False)
+class FlaxFlowMatchSchedulerOutput(FlaxSchedulerOutput):
+  """
+  Output class for the JAX FlowMatchScheduler's step function.
+
+  Attributes:
+    prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+      The computed sample at the previous timestep.
+    state (`FlowMatchSchedulerState`):
+      The updated scheduler state.
+  """
+
+  state: FlowMatchSchedulerState
+
+
+class FlaxFlowMatchScheduler(FlaxSchedulerMixin, ConfigMixin):
+  """
+  FlaxFlowMatchScheduler is a JAX/Flax conversion of a scheduler used for training video generation models like
+  WAN 2.1. It operates based on a "flow matching" paradigm.
+
+  This scheduler directly calculates sigmas for a continuous-time diffusion process, which can be beneficial for
+  certain types of models and training schemes.
+  """
+
+  dtype: jnp.dtype
+
+  @property
+  def has_state(self) -> bool:
+    return True
+
+  @register_to_config
+  def __init__(
+      self,
+      num_train_timesteps: int = 1000,
+      shift: float = 3.0,
+      sigma_max: float = 1.0,
+      sigma_min: float = 0.003 / 1.002,
+      inverse_timesteps: bool = False,
+      extra_one_step: bool = False,
+      reverse_sigmas: bool = False,
+      dtype: jnp.dtype = jnp.float32,
+  ):
+    self.dtype = dtype
+
+  def create_state(self) -> FlowMatchSchedulerState:
+    """Creates the initial state for the scheduler."""
+
+    return FlowMatchSchedulerState.create()
+
+  def set_timesteps(
+      self,
+      state: FlowMatchSchedulerState,
+      num_inference_steps: int = 100,
+      shape: Tuple = None,  # Not used but part of the standard API
+      denoising_strength: float = 1.0,
+      training: bool = False,
+      shift: Optional[float] = None,
+  ) -> FlowMatchSchedulerState:
+    """
+    Sets the discrete timesteps used for the diffusion chain.
+
+    Args:
+      state (`FlowMatchSchedulerState`):
+        The current scheduler state.
+      num_inference_steps (`int`):
+        The number of diffusion steps used when generating samples with a pre-trained model.
+      shape (`Tuple`):
+        The shape of the samples.
+      denoising_strength (`float`):
+        The strength of the denoising process.
+      training (`bool`):
+        Whether the scheduler is being used for training.
+      shift (`Optional[float]`):
+        An optional shift value to override the one in the config.
+
+    Returns:
+      `FlowMatchSchedulerState`: The updated scheduler state.
+    """
+    current_shift = shift if shift is not None else self.config.shift
+    sigma_start = self.config.sigma_min + (self.config.sigma_max - self.config.sigma_min) * denoising_strength
+
+    if self.config.extra_one_step:
+      sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps + 1, dtype=self.dtype)[:-1]
+    else:
+      sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps, dtype=self.dtype)
+
+    if self.config.inverse_timesteps:
+      sigmas = jnp.flip(sigmas, axis=0)  # `jnp.flip` takes `axis`; `dims` is the torch keyword and raises TypeError.
+
+    sigmas = current_shift * sigmas / (1 + (current_shift - 1) * sigmas)
+
+    if self.config.reverse_sigmas:
+      sigmas = 1 - sigmas
+
+    timesteps = sigmas * self.config.num_train_timesteps
+
+    linear_timesteps_weights = None
+    if training:
+      x = timesteps
+      y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+      y_shifted = y - jnp.min(y)
+      bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted))
+      linear_timesteps_weights = bsmntw_weighing
+
+    return state.replace(
+        sigmas=sigmas,
+        timesteps=timesteps,
+        linear_timesteps_weights=linear_timesteps_weights,
+        training=training,
+        num_inference_steps=num_inference_steps,
+    )
+
+  def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray:
+    """Finds the index of the closest timestep in the scheduler's `timesteps` array."""
+    timestep = jnp.asarray(timestep, dtype=state.timesteps.dtype)
+    if timestep.ndim == 0:
+      return jnp.argmin(jnp.abs(state.timesteps - timestep))
+    else:
+      diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None])
+      return jnp.argmin(diffs, axis=1)
+
+  def step(
+      self,
+      state: FlowMatchSchedulerState,
+      model_output: jnp.ndarray,
+      timestep: jnp.ndarray,
+      sample: jnp.ndarray,
+      to_final: bool = False,
+      return_dict: bool = True,
+  ) -> Union[FlaxFlowMatchSchedulerOutput, Tuple]:
+    """
+    Propagates the sample with the flow matching scheduler.
+
+    Args:
+      state (`FlowMatchSchedulerState`):
+        The current scheduler state.
+      model_output (`jnp.ndarray`):
+        The direct output from the learned diffusion model.
+      timestep (`jnp.ndarray`):
+        The current timestep in the diffusion chain (a scalar or a 1-D batch of timesteps).
+      sample (`jnp.ndarray`):
+        The current sample (e.g. noisy latents).
+      to_final (`bool`):
+        Whether this is the final step.
+      return_dict (`bool`):
+        Whether to return a `FlaxFlowMatchSchedulerOutput` object.
+
+    Returns:
+      `FlaxFlowMatchSchedulerOutput` or `tuple`: A tuple (`prev_sample`, `state`) or a
+      `FlaxFlowMatchSchedulerOutput` object containing the previous sample and the updated state.
+    """
+    timestep_id = self._find_timestep_id(state, timestep)
+    sigma = state.sigmas[timestep_id]
+
+    def get_next_sigma():
+      return state.sigmas[timestep_id + 1]
+
+    def get_final_sigma():
+      return jnp.array(1.0 if (self.config.inverse_timesteps or self.config.reverse_sigmas) else 0.0, dtype=sigma.dtype)
+    # Select per element with `jnp.where`: `jax.lax.cond` rejects branches whose shapes differ (scalar vs. batched).
+    is_final_step = jnp.logical_or(to_final, timestep_id + 1 >= state.timesteps.shape[0])
+    sigma_next = jnp.where(is_final_step, get_final_sigma(), get_next_sigma())
+
+    if jnp.ndim(timestep) != 0:
+      broadcast_shape = (-1,) + (1,) * (sample.ndim - 1)
+      sigma = sigma.reshape(broadcast_shape)
+      sigma_next = sigma_next.reshape(broadcast_shape)
+
+    prev_sample = sample + model_output * (sigma_next - sigma)
+
+    if not return_dict:
+      return (prev_sample, state)
+
+    return FlaxFlowMatchSchedulerOutput(prev_sample=prev_sample, state=state)
+
+  def return_to_timestep(
+      self, state: FlowMatchSchedulerState, timestep: jnp.ndarray, sample: jnp.ndarray, sample_stablized: jnp.ndarray
+  ) -> jnp.ndarray:
+    """Calculates the model output required to go from a stabilized sample back to the original sample."""
+    timestep_id = self._find_timestep_id(state, timestep)
+    sigma = state.sigmas[timestep_id]
+
+    if jnp.ndim(timestep) != 0:
+      sigma = sigma.reshape((-1,) + (1,) * (sample.ndim - 1))
+
+    model_output = (sample - sample_stablized) / sigma
+    return model_output
+
+  def add_noise(
+      self,
+      state: FlowMatchSchedulerState,
+      original_samples: jnp.ndarray,
+      noise: jnp.ndarray,
+      timesteps: jnp.ndarray,
+  ) -> jnp.ndarray:
+    """
+    Adds noise to the original samples according to the flow matching schedule.
+
+    Args:
+      state (`FlowMatchSchedulerState`):
+        The current scheduler state.
+      original_samples (`jnp.ndarray`):
+        The original clean samples.
+      noise (`jnp.ndarray`):
+        The noise to add to the samples.
+      timesteps (`jnp.ndarray`):
+        The timesteps that correspond to the noise levels.
+
+    Returns:
+      `jnp.ndarray`: The noisy samples.
+    """
+    if state.sigmas is None or state.timesteps is None:
+      raise ValueError("Scheduler's `sigmas` and `timesteps` are not set. Please call `set_timesteps` before `add_noise`.")
+
+    timestep_ids = self._find_timestep_id(state, timesteps)
+    sigmas = state.sigmas[timestep_ids]
+
+    broadcast_shape = (-1,) + (1,) * (original_samples.ndim - 1)
+    sigmas = sigmas.reshape(broadcast_shape)
+
+    noisy_samples = (1 - sigmas) * original_samples + sigmas * noise
+    return noisy_samples
+
+  def training_target(self, sample: jnp.ndarray, noise: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
+    """
+    Calculates the training target. For flow matching, this is typically the velocity, `x_1 - x_0`,
+    which is equivalent to `noise - sample` under this scheduler's `add_noise` definition.
+    """
+    target = noise - sample
+    return target
+
+  def training_weight(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray:
+    """Calculates the training weight for a given timestep."""
+    timestep_ids = self._find_timestep_id(state, timestep)
+    weights = state.linear_timesteps_weights[timestep_ids]  # Only populated by `set_timesteps(..., training=True)`.
+    return weights
+
+  def __len__(self) -> int:
+    return self.config.num_train_timesteps
diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py
index 9afd22674..79e7a0891 100644
--- a/src/maxdiffusion/tests/input_pipeline_interface_test.py
+++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py
@@ -506,7 +506,23 @@ def test_make_laion_tfrecord_iterator(self):
         from_pt=config.from_pt,
     )
 
-    train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size)
+    feature_description = {
+        "moments": tf.io.FixedLenFeature([], tf.string),
+        "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
+    }
+
+    def _parse_tfrecord_fn(example):
+      return tf.io.parse_single_example(example, feature_description)
+
+    train_iterator = make_data_iterator(
+        config,
+        jax.process_index(),
+        jax.process_count(),
+        mesh,
+        global_batch_size,
+        feature_description=feature_description,
+        prepare_sample_fn=_parse_tfrecord_fn,
+    )
     data = next(train_iterator)
     device_count = jax.device_count()
 
diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py
index 4cc81955b..a68cc6170 100644
--- a/src/maxdiffusion/trainers/sdxl_trainer.py
+++ b/src/maxdiffusion/trainers/sdxl_trainer.py
@@ -20,6 +20,7 @@
 import threading
 import time
 import numpy as np
+import tensorflow as tf
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
@@ -140,6 +141,26 @@ def load_dataset(self, pipeline, params, train_states):
         p_vae_apply=p_vae_apply,
     )
 
+    feature_description = {
+        "pixel_values": tf.io.FixedLenFeature([],
tf.string), + "input_ids": tf.io.FixedLenFeature([], tf.string), + "prompt_embeds": tf.io.FixedLenFeature([], tf.string), + "text_embeds": tf.io.FixedLenFeature([], tf.string), + } + + def prepare_sample(features): + pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32) + input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32) + prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) + text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) + + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "prompt_embeds": prompt_embeds, + "text_embeds": text_embeds, + } + data_iterator = make_data_iterator( config, jax.process_index(), @@ -148,6 +169,8 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size, tokenize_fn=tokenize_fn, image_transforms_fn=image_transforms_fn, + feature_description=feature_description, + prepare_sample_fn=prepare_sample, ) return data_iterator diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index b11626c27..6af138bc7 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -18,15 +18,26 @@ import datetime import functools import numpy as np +import threading +from concurrent.futures import ThreadPoolExecutor +import tensorflow as tf import jax.numpy as jnp import jax -import jax.tree_util as jtu from flax import nnx +from ..schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning -from ..schedulers import FlaxEulerDiscreteScheduler -from .. import max_utils, max_logging, train_utils, maxdiffusion_utils +from .. 
import max_utils, max_logging, train_utils from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) -from maxdiffusion.multihost_dataloading import _form_global_array +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) +from maxdiffusion.generate_wan import run as generate_wan +from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue) + + +def generate_sample(config, pipeline, filename_prefix): + """ + Generates a video to validate training did not corrupt the model + """ + generate_wan(config, pipeline, filename_prefix) class WanTrainer(WanCheckpointer): @@ -41,21 +52,18 @@ def __init__(self, config): def post_training_steps(self, pipeline, params, train_states, msg=""): pass - def create_scheduler(self, pipeline, params): - # TODO - set right scheduler - noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( - pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 - ) - noise_scheduler_state = noise_scheduler.set_timesteps( - state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" - ) + def create_scheduler(self): + """Creates and initializes the Flow Match scheduler for training.""" + noise_scheduler = FlaxFlowMatchScheduler(dtype=jnp.float32) + noise_scheduler_state = noise_scheduler.create_state() + noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True) return noise_scheduler, noise_scheduler_state def calculate_tflops(self, pipeline): max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") return 0 - def load_dataset(self, pipeline): + def load_dataset(self, mesh): # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 # Image pre-training - txt2img 256px # Image-video joint training - stage 1. 
256 px images and 192px 5 sec videos at fps=16 @@ -64,24 +72,59 @@ def load_dataset(self, pipeline): # prompt embeds shape: (1, 512, 4096) # For now, we will pass the same latents over and over # TODO - create a dataset - return maxdiffusion_utils.get_dummy_wan_inputs(self.config, pipeline, self.global_batch_size) + config = self.config + if config.dataset_type != "tfrecord" and not config.cache_latents_text_encoder_outputs: + raise ValueError( + "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" + ) + + feature_description = { + "latents": tf.io.FixedLenFeature([], tf.string), + "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), + } + + def prepare_sample(features): + latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) + encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) + return {"latents": latents, "encoder_hidden_states": encoder_hidden_states} + + data_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + self.global_batch_size, + feature_description=feature_description, + prepare_sample_fn=prepare_sample, + ) + return data_iterator def start_training(self): pipeline = self.load_checkpoint() - del pipeline.vae - dummy_inputs = self.load_dataset(pipeline) + # del pipeline.vae + + # Generate a sample before training to compare against generated sample after training. 
+ generate_sample(self.config, pipeline, filename_prefix="pre-training-") mesh = pipeline.mesh + data_iterator = self.load_dataset(mesh) + + # Load FlowMatch scheduler + scheduler, scheduler_state = self.create_scheduler() + pipeline.scheduler = scheduler + pipeline.scheduler_state = scheduler_state + optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) - dummy_inputs = tuple( - [jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs] - ) - self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) + self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator) - def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator): graphdef, state = nnx.split((pipeline.transformer, optimizer)) + writer = max_utils.initialize_summary_writer(self.config) + writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) + writer_thread.start() + num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) @@ -95,7 +138,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): state = state.to_pure_dict() p_train_step = jax.jit( - train_step, + functools.partial(train_step, scheduler=pipeline.scheduler), donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) @@ -113,52 +156,77 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): start_step = 0 per_device_tflops = self.calculate_tflops(pipeline) - for step in np.arange(start_step, self.config.max_train_steps): - if self.config.enable_profiler and step == 
first_profiling_step: - max_utils.activate_profiler(self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( - self.config.logical_axis_rules - ): - state, train_metric, rng = p_train_step(state, graphdef, data, rng) + scheduler_state = pipeline.scheduler_state + example_batch = load_next_batch(data_iterator, None, self.config) + with ThreadPoolExecutor(max_workers=1) as executor: + for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + start_step_time = datetime.datetime.now() + next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config) + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): + state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, example_batch, rng) + train_metric["scalar"]["learning/loss"].block_until_ready() + last_step_completion = datetime.datetime.now() - new_time = datetime.datetime.now() + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) - if self.config.enable_profiler and step == last_profiling_step: - max_utils.deactivate_profiler(self.config) + train_utils.record_scalar_metrics( + train_metric, last_step_completion - start_step_time, per_device_tflops, learning_rate_scheduler(step) + ) + if self.config.write_metrics: + train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + example_batch = next_batch_future.result() - train_utils.record_scalar_metrics( - train_metric, new_time - last_step_completion, per_device_tflops, learning_rate_scheduler(step) - ) - if self.config.write_metrics: - train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, 
self.config) - last_step_completion = new_time + _metrics_queue.put(None) + writer_thread.join() + if writer: + writer.flush() + # load new state for trained tranformer + graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) + pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state) -def train_step(state, graphdef, data, rng): - return step_optimizer(graphdef, state, data, rng) + generate_sample(self.config, pipeline, filename_prefix="post-training-") -def step_optimizer(graphdef, state, data, rng): - _, new_rng = jax.random.split(rng) +def train_step(state, graphdef, scheduler_state, data, rng, scheduler): + return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng) - def loss_fn(model): - latents, prompt_embeds, timesteps = data - noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) +def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng): + _, new_rng, timestep_rng = jax.random.split(rng, num=3) - # TODO - add noise here + def loss_fn(model): + latents = data["latents"] + encoder_hidden_states = data["encoder_hidden_states"] + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + scheduler.config.num_train_timesteps, + ) + noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) + noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) model_pred = model( - hidden_states=noise, + hidden_states=noisy_latents, timestep=timesteps, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, is_uncond=jnp.array(False, dtype=jnp.bool_), slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) - target = noise - latents - loss = (target - model_pred) ** 2 + + training_target = scheduler.training_target(latents, noise, timesteps) + training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) + loss = (training_target - model_pred) 
** 2 + loss = loss * training_weight loss = jnp.mean(loss) - # breakpoint() + return loss model, optimizer = nnx.merge(graphdef, state) @@ -167,4 +235,4 @@ def loss_fn(model): state = nnx.state((model, optimizer)) state = state.to_pure_dict() metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} - return state, metrics, new_rng + return state, scheduler_state, metrics, new_rng