From 1536f42ef2e5b34a7e95fdb9faeed5784988d6fe Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 25 Jun 2025 02:01:33 +0000 Subject: [PATCH 1/4] flow match scheduler + data to tf records --- src/maxdiffusion/configs/base_wan_14b.yml | 4 + .../data_preprocessing/__init__.py | 15 + .../wan_txt2vid_data_preprocessing.py | 149 +++++++++ .../pipelines/wan/wan_pipeline.py | 7 +- src/maxdiffusion/schedulers/__init__.py | 3 +- .../schedulers/scheduling_flow_match_flax.py | 294 ++++++++++++++++++ src/maxdiffusion/trainers/wan_trainer.py | 47 +-- 7 files changed, 496 insertions(+), 23 deletions(-) create mode 100644 src/maxdiffusion/data_preprocessing/__init__.py create mode 100644 src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py create mode 100644 src/maxdiffusion/schedulers/scheduling_flow_match_flax.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1dd81b075..073c90638 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -185,6 +185,10 @@ per_device_batch_size: 1 # If global_batch_size % jax.device_count is not 0, use FSDP sharding. global_batch_size: 0 +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 + warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. diff --git a/src/maxdiffusion/data_preprocessing/__init__.py b/src/maxdiffusion/data_preprocessing/__init__.py new file mode 100644 index 000000000..55bca151a --- /dev/null +++ b/src/maxdiffusion/data_preprocessing/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ \ No newline at end of file diff --git a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py new file mode 100644 index 000000000..19e4b52f1 --- /dev/null +++ b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py @@ -0,0 +1,149 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +""" +Prepare tfrecords with latents and text embeddings preprocessed. +1. 
Download the dataset +""" + +import os +import functools +from absl import app +from typing import Sequence, Union, List +from datasets import load_dataset +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from maxdiffusion import pyconfig, max_utils +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion.video_processor import VideoProcessor + +import tensorflow as tf + +def image_feature(value): + """Returns a bytes_list from a string / byte.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])) + + +def bytes_feature(value): + """Returns a bytes_list from a string / byte.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()])) + + +def float_feature(value): + """Returns a float_list from a float / double.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) + + +def int64_feature(value): + """Returns an int64_list from a bool / enum / int / uint.""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def float_feature_list(value): + """Returns a list of float_list from a float / double.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + +def create_example(latent, hidden_states): + latent = tf.io.serialize_tensor(latent) + hidden_states = tf.io.serialize_tensor(hidden_states) + feature = { + "latents": bytes_feature(latent), + "encoder_hidden_states": bytes_feature(hidden_states), + } + example = tf.train.Example(features=tf.train.Features(feature=feature)) + return example.SerializeToString() + + +def text_encode(pipeline, prompt: Union[str, List[str]]): + encoder_hidden_states = pipeline._get_t5_prompt_embeds(prompt) + encoder_hidden_states = encoder_hidden_states.detach().numpy() + return encoder_hidden_states + +def vae_encode(video, rng, vae, vae_cache): + latent = vae.encode(video, feat_cache=vae_cache) + latent = latent.latent_dist.sample(rng) + 
return latent + +def generate_dataset(config, pipeline): + + tfrecords_dir = config.tfrecords_dir + if not os.path.exists(tfrecords_dir): + os.makedirs(tfrecords_dir) + + tf_rec_num = 0 + no_records_per_shard = config.no_records_per_shard + global_record_count = 0 + writer = tf.io.TFRecordWriter( + tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) + ) + shard_record_count = 0 + + # create mesh + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + + vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample) + video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial) + + # jit vae fun. + p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache)) + + # Load dataset + ds = load_dataset(config.dataset_name, split='train') + ds = ds.shuffle(seed=config.seed) + ds = ds.select_columns([config.caption_column, config.image_column]) + batch_size = 10 + for i in range(0, len(ds), batch_size): + rng, new_rng = jax.random.split(rng) + text = ds[i:i+batch_size]['text'] + video = ds[i:i+batch_size]['image'] + + video = [np.expand_dims(np.array(i), axis=0) for i in video] + video = video_processor.preprocess_video(video, height=config.height, width=config.width) + video = jnp.array(np.array(video), dtype=config.weights_dtype) + with mesh: + latents = p_vae_encode(video=video, rng=new_rng) + encoder_hidden_states = text_encode(pipeline, text) + example = create_example(latents, encoder_hidden_states) + writer.write(example) + shard_record_count += batch_size + global_record_count += batch_size + if shard_record_count >= no_records_per_shard: + writer.close() + writer = tf.io.TFRecordWriter( + tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) + ) + shard_record_count = 0 + tf_rec_num +=1 + + + +def run(config): + pipeline = 
WanPipeline.from_pretrained(config, load_transformer=False) + # Don't need the transformer for preprocessing. + generate_dataset(config, pipeline) + + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index a3be8e138..9dff5fe25 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -221,7 +221,7 @@ def load_scheduler(cls, config): return scheduler, scheduler_state @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False): + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) @@ -232,8 +232,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False): scheduler_state = None text_encoder = None if not vae_only: - with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + if load_transformer: + with mesh: + transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) diff --git a/src/maxdiffusion/schedulers/__init__.py b/src/maxdiffusion/schedulers/__init__.py index edd249de1..26743e3d7 100644 --- a/src/maxdiffusion/schedulers/__init__.py +++ b/src/maxdiffusion/schedulers/__init__.py @@ -43,7 +43,7 @@ _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"] _import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"] - _import_structure["scheduling_euler_discrete_flax"] 
= ["FlaxEulerDiscreteScheduler"] + _import_structure["scheduling_flow_match_flax"] = ["FlaxFlowMatchScheduler"] _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"] _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] @@ -70,6 +70,7 @@ from .scheduling_ddpm_flax import FlaxDDPMScheduler from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler + from .scheduling_flow_match_flax import FlowMatchScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py new file mode 100644 index 000000000..da447e91d --- /dev/null +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py @@ -0,0 +1,294 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This is a JAX/Flax conversion of a PyTorch implementation. +# The original PyTorch code was provided by the user. 
+ +from typing import Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jnp + +# Assuming these are part of your project structure, similar to the UniPC example +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import ( + FlaxSchedulerMixin, + FlaxSchedulerOutput, +) + + +@flax.struct.dataclass +class FlowMatchSchedulerState: + """ + Data class to hold the mutable state of the FlaxFlowMatchScheduler. + """ + + sigmas: jnp.ndarray + timesteps: jnp.ndarray + linear_timesteps_weights: Optional[jnp.ndarray] + training: bool + num_inference_steps: int # Store for training weight calculation + + @classmethod + def create(cls): + return cls( + sigmas=None, + timesteps=None, + linear_timesteps_weights=None, + training=False, + num_inference_steps=0, + ) + + +@flax.struct.dataclass(frozen=False) +class FlaxFlowMatchSchedulerOutput(FlaxSchedulerOutput): + """ + Output class for the JAX FlowMatchScheduler's step function. + + Attributes: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The computed sample at the previous timestep. + state (`FlowMatchSchedulerState`): + The updated scheduler state. + """ + state: FlowMatchSchedulerState + + +class FlaxFlowMatchScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + FlaxFlowMatchScheduler is a JAX/Flax conversion of a scheduler used for training video generation models like + WAN 2.1. It operates based on a "flow matching" paradigm. + + This scheduler directly calculates sigmas for a continuous-time diffusion process, which can be beneficial for + certain types of models and training schemes. 
+ """ + + dtype: jnp.dtype + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 3.0, + sigma_max: float = 1.0, + sigma_min: float = 0.003 / 1.002, + inverse_timesteps: bool = False, + extra_one_step: bool = False, + reverse_sigmas: bool = False, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self) -> FlowMatchSchedulerState: + """Creates the initial state for the scheduler.""" + + return FlowMatchSchedulerState.create() + + def set_timesteps( + self, + state: FlowMatchSchedulerState, + num_inference_steps: int = 100, + shape: Tuple = None, # Not used but part of the standard API + denoising_strength: float = 1.0, + training: bool = False, + shift: Optional[float] = None, + ) -> FlowMatchSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. + + Args: + state (`FlowMatchSchedulerState`): + The current scheduler state. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + The shape of the samples. + denoising_strength (`float`): + The strength of the denoising process. + training (`bool`): + Whether the scheduler is being used for training. + shift (`Optional[float]`): + An optional shift value to override the one in the config. + + Returns: + `FlowMatchSchedulerState`: The updated scheduler state. 
+ """ + current_shift = shift if shift is not None else self.config.shift + sigma_start = self.config.sigma_min + (self.config.sigma_max - self.config.sigma_min) * denoising_strength + + if self.config.extra_one_step: + sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps + 1, dtype=self.dtype)[:-1] + else: + sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps, dtype=self.dtype) + + if self.config.inverse_timesteps: + sigmas = jnp.flip(sigmas, dims=[0]) + + sigmas = current_shift * sigmas / (1 + (current_shift - 1) * sigmas) + + if self.config.reverse_sigmas: + sigmas = 1 - sigmas + + timesteps = sigmas * self.config.num_train_timesteps + + linear_timesteps_weights = None + if training: + x = timesteps + y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) + y_shifted = y - jnp.min(y) + bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted)) + linear_timesteps_weights = bsmntw_weighing + + return state.replace( + sigmas=sigmas, + timesteps=timesteps, + linear_timesteps_weights=linear_timesteps_weights, + training=training, + num_inference_steps=num_inference_steps, + ) + + def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: + """Finds the index of the closest timestep in the scheduler's `timesteps` array.""" + timestep = jnp.asarray(timestep, dtype=state.timesteps.dtype) + if timestep.ndim == 0: + return jnp.argmin(jnp.abs(state.timesteps - timestep)) + else: + diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None]) + return jnp.argmin(diffs, axis=1) + + def step( + self, + state: FlowMatchSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, + sample: jnp.ndarray, + to_final: bool = False, + return_dict: bool = True, + ) -> Union[FlaxFlowMatchSchedulerOutput, Tuple]: + """ + Propagates the sample with the flow matching scheduler. + + Args: + state (`FlowMatchSchedulerState`): + The current scheduler state. 
+ model_output (`jnp.ndarray`): + The direct output from the learned diffusion model. + timestep (`jnp.ndarray`): + The current timestep in the diffusion chain. + sample (`jnp.ndarray`): + The current sample (e.g. noisy latents). + to_final (`bool`): + Whether this is the final step. + return_dict (`bool`): + Whether to return a `FlaxFlowMatchSchedulerOutput` object. + + Returns: + `FlaxFlowMatchSchedulerOutput` or `tuple`: A tuple (`prev_sample`, `state`) or a + `FlaxFlowMatchSchedulerOutput` object containing the previous sample and the updated state. + """ + timestep_id = self._find_timestep_id(state, timestep) + sigma = state.sigmas[timestep_id] + + def get_next_sigma(): + return state.sigmas[timestep_id + 1] + + def get_final_sigma(): + return jnp.array(1.0 if (self.config.inverse_timesteps or self.config.reverse_sigmas) else 0.0, dtype=sigma.dtype) + + is_final_step = to_final or jnp.all(timestep_id + 1 >= state.timesteps.shape[0]) + sigma_next = jax.lax.cond(is_final_step, get_final_sigma, get_next_sigma) + + if jnp.ndim(timestep) != 0: + broadcast_shape = (-1,) + (1,) * (sample.ndim - 1) + sigma = sigma.reshape(broadcast_shape) + sigma_next = sigma_next.reshape(broadcast_shape) + + prev_sample = sample + model_output * (sigma_next - sigma) + + if not return_dict: + return (prev_sample, state) + + return FlaxFlowMatchSchedulerOutput(prev_sample=prev_sample, state=state) + + def return_to_timestep( + self, state: FlowMatchSchedulerState, timestep: jnp.ndarray, sample: jnp.ndarray, sample_stablized: jnp.ndarray + ) -> jnp.ndarray: + """Calculates the model output required to go from a stabilized sample back to the original sample.""" + timestep_id = self._find_timestep_id(state, timestep) + sigma = state.sigmas[timestep_id] + + if jnp.ndim(timestep) != 0: + sigma = sigma.reshape((-1,) + (1,) * (sample.ndim - 1)) + + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise( + self, + state: FlowMatchSchedulerState, + 
original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + """ + Adds noise to the original samples according to the flow matching schedule. + + Args: + state (`FlowMatchSchedulerState`): + The current scheduler state. + original_samples (`jnp.ndarray`): + The original clean samples. + noise (`jnp.ndarray`): + The noise to add to the samples. + timesteps (`jnp.ndarray`): + The timesteps that correspond to the noise levels. + + Returns: + `jnp.ndarray`: The noisy samples. + """ + if state.sigmas is None or state.timesteps is None: + raise ValueError( + "Scheduler's `sigmas` and `timesteps` are not set. Please call `set_timesteps` before `add_noise`." + ) + + timestep_ids = self._find_timestep_id(state, timesteps) + sigmas = state.sigmas[timestep_ids] + + broadcast_shape = (-1,) + (1,) * (original_samples.ndim - 1) + sigmas = sigmas.reshape(broadcast_shape) + + noisy_samples = (1 - sigmas) * original_samples + sigmas * noise + return noisy_samples + + def training_target(self, sample: jnp.ndarray, noise: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: + """ + Calculates the training target. For flow matching, this is typically the velocity, `x_1 - x_0`, + which is equivalent to `noise - sample` under this scheduler's `add_noise` definition. 
+ """ + target = noise - sample + return target + + def training_weight(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: + """Calculates the training weight for a given timestep.""" + timestep_ids = self._find_timestep_id(state, timestep) + weights = state.linear_timesteps_weights[timestep_ids] + return weights + + def __len__(self) -> int: + return self.config.num_train_timesteps \ No newline at end of file diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index b11626c27..ef2b243cd 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -21,7 +21,8 @@ import jax.numpy as jnp import jax import jax.tree_util as jtu -from flax import nnx +from flax import nnx +from ..schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning from ..schedulers import FlaxEulerDiscreteScheduler from .. import max_utils, max_logging, train_utils, maxdiffusion_utils @@ -41,14 +42,11 @@ def __init__(self, config): def post_training_steps(self, pipeline, params, train_states, msg=""): pass - def create_scheduler(self, pipeline, params): - # TODO - set right scheduler - noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( - pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 - ) - noise_scheduler_state = noise_scheduler.set_timesteps( - state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" - ) + def create_scheduler(self): + """Creates and initializes the Flow Match scheduler for training.""" + noise_scheduler = FlaxFlowMatchScheduler(dtype=jnp.float32) + noise_scheduler_state = noise_scheduler.create_state() + noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True) return noise_scheduler, noise_scheduler_state def 
calculate_tflops(self, pipeline): @@ -71,7 +69,14 @@ def start_training(self): pipeline = self.load_checkpoint() del pipeline.vae dummy_inputs = self.load_dataset(pipeline) + mesh = pipeline.mesh + + # Load FlowMatch scheduler + scheduler, scheduler_state = self.create_scheduler() + pipeline.scheduler = scheduler + pipeline.scheduler_state = scheduler_state + optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) dummy_inputs = tuple( [jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs] @@ -95,7 +100,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): state = state.to_pure_dict() p_train_step = jax.jit( - train_step, + functools.partial(train_step, scheduler=pipeline.scheduler), donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) @@ -113,13 +118,15 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): start_step = 0 per_device_tflops = self.calculate_tflops(pipeline) + scheduler_state = pipeline.scheduler_state + for step in np.arange(start_step, self.config.max_train_steps): if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( self.config.logical_axis_rules ): - state, train_metric, rng = p_train_step(state, graphdef, data, rng) + state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, data, rng) new_time = datetime.datetime.now() @@ -134,11 +141,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): last_step_completion = new_time -def train_step(state, graphdef, data, rng): - return step_optimizer(graphdef, state, data, rng) +def train_step(state, graphdef, scheduler_state, data, rng, scheduler): + return step_optimizer(graphdef, state, scheduler, scheduler_state, data, 
rng) -def step_optimizer(graphdef, state, data, rng): +def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng): _, new_rng = jax.random.split(rng) def loss_fn(model): @@ -147,18 +154,20 @@ def loss_fn(model): noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) # TODO - add noise here + noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) model_pred = model( - hidden_states=noise, + hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=prompt_embeds, is_uncond=jnp.array(False, dtype=jnp.bool_), slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) - target = noise - latents - loss = (target - model_pred) ** 2 + + training_target = scheduler.training_target(latents, noise, timesteps) + loss = ((training_target - model_pred) ** 2) * scheduler.training_weight(scheduler_state, timesteps) loss = jnp.mean(loss) - # breakpoint() + return loss model, optimizer = nnx.merge(graphdef, state) @@ -167,4 +176,4 @@ def loss_fn(model): state = nnx.state((model, optimizer)) state = state.to_pure_dict() metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} - return state, metrics, new_rng + return state, scheduler_state, metrics, new_rng From afc888288de0c84a37a2bad67e99027512da3480 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 26 Jun 2025 01:12:12 +0000 Subject: [PATCH 2/4] training pipeline with image dataset. 
--- src/maxdiffusion/configs/base_wan_14b.yml | 5 +- .../wan_txt2vid_data_preprocessing.py | 19 +-- src/maxdiffusion/generate_wan.py | 7 +- .../input_pipeline/_tfds_data_processing.py | 32 ++-- .../input_pipeline_interface.py | 12 ++ src/maxdiffusion/trainers/sdxl_trainer.py | 18 +++ src/maxdiffusion/trainers/wan_trainer.py | 137 +++++++++++++----- 7 files changed, 162 insertions(+), 68 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 073c90638..9ef1b72d6 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -141,14 +141,15 @@ ici_tensor_parallelism: 1 # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tf' +dataset_type: 'tfrecord' cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory # prepare image latents and text encoder outputs # Reduce memory consumption and reduce step time during training # transformed dataset is saved at dataset_save_location -dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +dataset_save_location: '' +load_tfrecord_cached: True train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' diff --git a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py index 19e4b52f1..c3fc64b1b 100644 --- a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py +++ b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py @@ -112,25 +112,26 @@ def generate_dataset(config, pipeline): for i in range(0, len(ds), batch_size): rng, new_rng = jax.random.split(rng) text = ds[i:i+batch_size]['text'] - video = ds[i:i+batch_size]['image'] + videos = ds[i:i+batch_size]['image'] - video = [np.expand_dims(np.array(i), 
axis=0) for i in video] - video = video_processor.preprocess_video(video, height=config.height, width=config.width) - video = jnp.array(np.array(video), dtype=config.weights_dtype) + videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos] + video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype) with mesh: latents = p_vae_encode(video=video, rng=new_rng) + latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) encoder_hidden_states = text_encode(pipeline, text) - example = create_example(latents, encoder_hidden_states) - writer.write(example) - shard_record_count += batch_size - global_record_count += batch_size + for latent, encoder_hidden_state in zip(latents, encoder_hidden_states): + writer.write(create_example(latent, encoder_hidden_state)) + shard_record_count += 1 + global_record_count += 1 + if shard_record_count >= no_records_per_shard: writer.close() + tf_rec_num +=1 writer = tf.io.TFRecordWriter( tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) ) shard_record_count = 0 - tf_rec_num +=1 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 760d655cc..1a6dcb3b9 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -21,9 +21,10 @@ from maxdiffusion.utils import export_to_video -def run(config): +def run(config, pipeline=None, filename_prefix=''): print("seed: ", config.seed) - pipeline = WanPipeline.from_pretrained(config) + if pipeline is None: + pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() # Skip layer guidance @@ -59,7 +60,7 @@ def run(config): print("compile time: ", (time.perf_counter() - s0)) for i in range(len(videos)): - export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps) + export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps) s0 = time.perf_counter() videos 
= pipeline( prompt=prompt, diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 6b588ed2d..34b5435c4 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -79,29 +79,17 @@ def make_cached_tfrecord_iterator( dataloading_host_count, mesh, global_batch_size, + feature_description, + prepare_sample_fn ): """ New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings: latents, input_ids, prompt_embeds, and text_embeds. """ - feature_description = { - "pixel_values": tf.io.FixedLenFeature([], tf.string), - "input_ids": tf.io.FixedLenFeature([], tf.string), - "prompt_embeds": tf.io.FixedLenFeature([], tf.string), - "text_embeds": tf.io.FixedLenFeature([], tf.string), - } def _parse_tfrecord_fn(example): return tf.io.parse_single_example(example, feature_description) - def prepare_sample(features): - pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32) - input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32) - prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) - text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) - - return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds} - # This pipeline reads the sharded files and applies the parsing and preparation. 
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) @@ -109,7 +97,7 @@ def prepare_sample(features): tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) .shard(num_shards=dataloading_host_count, index=dataloading_host_index) .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) - .map(prepare_sample, num_parallel_calls=AUTOTUNE) + .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE) .shuffle(global_batch_size * 10) .batch(global_batch_size // dataloading_host_count, drop_remainder=True) .repeat(-1) @@ -128,6 +116,8 @@ def make_tfrecord_iterator( dataloading_host_count, mesh, global_batch_size, + feature_description, + prepare_sample_fn ): """Iterator for TFRecord format. For Laion dataset, check out preparation script @@ -136,12 +126,20 @@ def make_tfrecord_iterator( # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. - # Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. + # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. 
if (config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location) and 'load_tfrecord_cached'in config.get_keys() and config.load_tfrecord_cached): - return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size) + return make_cached_tfrecord_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description, + prepare_sample_fn + ) feature_description = { "moments": tf.io.FixedLenFeature([], tf.string), diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 3a78ff09b..e940aca30 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -50,8 +50,18 @@ def make_data_iterator( global_batch_size, tokenize_fn=None, image_transforms_fn=None, + feature_description=None, + prepare_sample_fn=None ): """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" + + if config.dataset_type == "hf" or config.dataset_type == "tf": + if tokenize_fn is None or image_transforms_fn is None: + raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn") + + if config.dataset_type == "tfrecord" and config.cache_latents_text_encoder_outputs and feature_description is None or prepare_sample_fn is None: + raise ValueError(f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True.") + if config.dataset_type == "hf": return _hf_data_processing.make_hf_streaming_iterator( config, @@ -87,6 +97,8 @@ def make_data_iterator( dataloading_host_count, mesh, global_batch_size, + feature_description, + prepare_sample_fn ) else: assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, 
grain)" diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index 4cc81955b..3dcaa57d0 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -20,6 +20,7 @@ import threading import time import numpy as np +import tensorflow as tf import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P @@ -140,6 +141,21 @@ def load_dataset(self, pipeline, params, train_states): p_vae_apply=p_vae_apply, ) + feature_description = { + "pixel_values": tf.io.FixedLenFeature([], tf.string), + "input_ids": tf.io.FixedLenFeature([], tf.string), + "prompt_embeds": tf.io.FixedLenFeature([], tf.string), + "text_embeds": tf.io.FixedLenFeature([], tf.string), + } + + def prepare_sample(features): + pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32) + input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32) + prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) + text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) + + return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds} + data_iterator = make_data_iterator( config, jax.process_index(), @@ -148,6 +164,8 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size, tokenize_fn=tokenize_fn, image_transforms_fn=image_transforms_fn, + feature_description=feature_description, + prepare_sample_fn=prepare_sample ) return data_iterator diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index ef2b243cd..c9dc4f34d 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -18,6 +18,9 @@ import datetime import functools import numpy as np +import threading +from concurrent.futures import ThreadPoolExecutor +import tensorflow as tf import jax.numpy as jnp import 
jax import jax.tree_util as jtu @@ -28,6 +31,19 @@ from .. import max_utils, max_logging, train_utils, maxdiffusion_utils from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) from maxdiffusion.multihost_dataloading import _form_global_array +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) +from maxdiffusion.generate_wan import run as generate_wan +from maxdiffusion.train_utils import ( + _tensorboard_writer_worker, + load_next_batch, + _metrics_queue +) + +def generate_sample(config, pipeline, filename_prefix): + """ + Generates a video to validate training did not corrupt the model + """ + generate_wan(config, pipeline, filename_prefix) class WanTrainer(WanCheckpointer): @@ -53,7 +69,7 @@ def calculate_tflops(self, pipeline): max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") return 0 - def load_dataset(self, pipeline): + def load_dataset(self, mesh): # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 # Image pre-training - txt2img 256px # Image-video joint training - stage 1. 
256 px images and 192px 5 sec videos at fps=16 @@ -62,15 +78,40 @@ # prompt embeds shape: (1, 512, 4096) # For now, we will pass the same latents over and over # TODO - create a dataset - return maxdiffusion_utils.get_dummy_wan_inputs(self.config, pipeline, self.global_batch_size) + config = self.config + if config.dataset_type != "tfrecord" or not config.cache_latents_text_encoder_outputs: + raise ValueError("Wan 2.1 training only supports config.dataset_type set to tfrecord and config.cache_latents_text_encoder_outputs set to True") + + feature_description = { + "latents" : tf.io.FixedLenFeature([], tf.string), + "encoder_hidden_states" : tf.io.FixedLenFeature([], tf.string), + } + + def prepare_sample(features): + latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) + encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) + return {"latents" : latents, "encoder_hidden_states" : encoder_hidden_states} + + data_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + self.global_batch_size, + feature_description=feature_description, + prepare_sample_fn=prepare_sample + ) + return data_iterator def start_training(self): pipeline = self.load_checkpoint() - del pipeline.vae - dummy_inputs = self.load_dataset(pipeline) - + #del pipeline.vae + + # Generate a sample before training to compare against generated sample after training. 
+ generate_sample(self.config, pipeline, filename_prefix='pre-training-') mesh = pipeline.mesh + data_iterator = self.load_dataset(mesh) # Load FlowMatch scheduler scheduler, scheduler_state = self.create_scheduler() @@ -78,15 +119,16 @@ def start_training(self): pipeline.scheduler_state = scheduler_state optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) - dummy_inputs = tuple( - [jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs] - ) - self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs) + self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator) - def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator): graphdef, state = nnx.split((pipeline.transformer, optimizer)) + writer = max_utils.initialize_summary_writer(self.config) + writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) + writer_thread.start() + num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) @@ -119,26 +161,40 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data): per_device_tflops = self.calculate_tflops(pipeline) scheduler_state = pipeline.scheduler_state - - for step in np.arange(start_step, self.config.max_train_steps): - if self.config.enable_profiler and step == first_profiling_step: - max_utils.activate_profiler(self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( - self.config.logical_axis_rules - ): - state, scheduler_state, train_metric, rng = 
p_train_step(state, graphdef, scheduler_state, data, rng) - - new_time = datetime.datetime.now() - - if self.config.enable_profiler and step == last_profiling_step: - max_utils.deactivate_profiler(self.config) - - train_utils.record_scalar_metrics( - train_metric, new_time - last_step_completion, per_device_tflops, learning_rate_scheduler(step) - ) - if self.config.write_metrics: - train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) - last_step_completion = new_time + example_batch = load_next_batch(data_iterator, None, self.config) + with ThreadPoolExecutor(max_workers=1) as executor: + for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + start_step_time = datetime.datetime.now() + next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config) + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): + state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, example_batch, rng) + train_metric["scalar"]["learning/loss"].block_until_ready() + last_step_completion = datetime.datetime.now() + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + + train_utils.record_scalar_metrics( + train_metric, last_step_completion - start_step_time, per_device_tflops, learning_rate_scheduler(step) + ) + if self.config.write_metrics: + train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + example_batch = next_batch_future.result() + + _metrics_queue.put(None) + writer_thread.join() + if writer: + writer.flush() + + # load new state for trained tranformer + graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
+ pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state) + + generate_sample(self.config, pipeline, filename_prefix='post-training-') def train_step(state, graphdef, scheduler_state, data, rng, scheduler): @@ -146,26 +202,33 @@ def train_step(state, graphdef, scheduler_state, data, rng, scheduler): def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng): - _, new_rng = jax.random.split(rng) + _, new_rng, timestep_rng = jax.random.split(rng, num=3) def loss_fn(model): - latents, prompt_embeds, timesteps = data - + latents = data['latents'] + encoder_hidden_states = data['encoder_hidden_states'] + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + scheduler.config.num_train_timesteps, + ) noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) - - # TODO - add noise here noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) model_pred = model( hidden_states=noisy_latents, timestep=timesteps, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, is_uncond=jnp.array(False, dtype=jnp.bool_), slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) training_target = scheduler.training_target(latents, noise, timesteps) - loss = ((training_target - model_pred) ** 2) * scheduler.training_weight(scheduler_state, timesteps) + training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) + loss = ((training_target - model_pred) ** 2) + loss = loss * training_weight loss = jnp.mean(loss) return loss From a6bc42b067ede593056fb23f80e1958731e26493 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 26 Jun 2025 14:25:02 +0000 Subject: [PATCH 3/4] lint --- .../data_preprocessing/__init__.py | 2 +- .../wan_txt2vid_data_preprocessing.py | 27 +- src/maxdiffusion/generate_wan.py | 2 +- .../input_pipeline/_tfds_data_processing.py | 39 +- .../input_pipeline_interface.py | 19 +- 
.../pipelines/wan/wan_pipeline.py | 3 +- .../schedulers/scheduling_flow_match_flax.py | 491 +++++++++--------- src/maxdiffusion/trainers/sdxl_trainer.py | 17 +- src/maxdiffusion/trainers/wan_trainer.py | 70 ++- 9 files changed, 336 insertions(+), 334 deletions(-) diff --git a/src/maxdiffusion/data_preprocessing/__init__.py b/src/maxdiffusion/data_preprocessing/__init__.py index 55bca151a..7e4185f36 100644 --- a/src/maxdiffusion/data_preprocessing/__init__.py +++ b/src/maxdiffusion/data_preprocessing/__init__.py @@ -12,4 +12,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - """ \ No newline at end of file + """ diff --git a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py index c3fc64b1b..ae0b15f47 100644 --- a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py +++ b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py @@ -34,6 +34,7 @@ import tensorflow as tf + def image_feature(value): """Returns a bytes_list from a string / byte.""" return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])) @@ -58,6 +59,7 @@ def float_feature_list(value): """Returns a list of float_list from a float / double.""" return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + def create_example(latent, hidden_states): latent = tf.io.serialize_tensor(latent) hidden_states = tf.io.serialize_tensor(hidden_states) @@ -74,11 +76,13 @@ def text_encode(pipeline, prompt: Union[str, List[str]]): encoder_hidden_states = encoder_hidden_states.detach().numpy() return encoder_hidden_states + def vae_encode(video, rng, vae, vae_cache): latent = vae.encode(video, feat_cache=vae_cache) latent = latent.latent_dist.sample(rng) return latent - + + def generate_dataset(config, pipeline): tfrecords_dir = 
config.tfrecords_dir @@ -99,21 +103,21 @@ def generate_dataset(config, pipeline): rng = jax.random.key(config.seed) vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample) - video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial) - + video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial) + # jit vae fun. p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache)) - + # Load dataset - ds = load_dataset(config.dataset_name, split='train') + ds = load_dataset(config.dataset_name, split="train") ds = ds.shuffle(seed=config.seed) ds = ds.select_columns([config.caption_column, config.image_column]) batch_size = 10 for i in range(0, len(ds), batch_size): rng, new_rng = jax.random.split(rng) - text = ds[i:i+batch_size]['text'] - videos = ds[i:i+batch_size]['image'] - + text = ds[i : i + batch_size]["text"] + videos = ds[i : i + batch_size]["image"] + videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos] video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype) with mesh: @@ -127,24 +131,23 @@ def generate_dataset(config, pipeline): if shard_record_count >= no_records_per_shard: writer.close() - tf_rec_num +=1 + tf_rec_num += 1 writer = tf.io.TFRecordWriter( tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) ) shard_record_count = 0 - def run(config): pipeline = WanPipeline.from_pretrained(config, load_transformer=False) # Don't need the transformer for preprocessing. 
generate_dataset(config, pipeline) - def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) + if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 1a6dcb3b9..8486c79d5 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -21,7 +21,7 @@ from maxdiffusion.utils import export_to_video -def run(config, pipeline=None, filename_prefix=''): +def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) if pipeline is None: pipeline = WanPipeline.from_pretrained(config) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 34b5435c4..454f65785 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -73,14 +73,9 @@ def make_tf_iterator( train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter + def make_cached_tfrecord_iterator( - config, - dataloading_host_index, - dataloading_host_count, - mesh, - global_batch_size, - feature_description, - prepare_sample_fn + config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn ): """ New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings: @@ -111,13 +106,7 @@ def _parse_tfrecord_fn(example): # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py def make_tfrecord_iterator( - config, - dataloading_host_index, - dataloading_host_count, - mesh, - global_batch_size, - feature_description, - prepare_sample_fn + config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn ): """Iterator for TFRecord format. 
For Laion dataset, check out preparation script @@ -127,18 +116,20 @@ def make_tfrecord_iterator( # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. - if (config.cache_latents_text_encoder_outputs + if ( + config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location) - and 'load_tfrecord_cached'in config.get_keys() - and config.load_tfrecord_cached): + and "load_tfrecord_cached" in config.get_keys() + and config.load_tfrecord_cached + ): return make_cached_tfrecord_iterator( - config, - dataloading_host_index, - dataloading_host_count, - mesh, - global_batch_size, - feature_description, - prepare_sample_fn + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description, + prepare_sample_fn, ) feature_description = { diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index e940aca30..d0be27144 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -51,16 +51,23 @@ def make_data_iterator( tokenize_fn=None, image_transforms_fn=None, feature_description=None, - prepare_sample_fn=None + prepare_sample_fn=None, ): """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" - + if config.dataset_type == "hf" or config.dataset_type == "tf": if tokenize_fn is None or image_transforms_fn is None: raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn") - - if config.dataset_type == "tfrecord" and config.cache_latents_text_encoder_outputs and feature_description is None or prepare_sample_fn is None: - raise 
ValueError(f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True.") + + if ( + config.dataset_type == "tfrecord" + and config.cache_latents_text_encoder_outputs + and (feature_description is None + or prepare_sample_fn is None) + ): + raise ValueError( + f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True." + ) if config.dataset_type == "hf": return _hf_data_processing.make_hf_streaming_iterator( @@ -98,7 +105,7 @@ def make_data_iterator( mesh, global_batch_size, feature_description, - prepare_sample_fn + prepare_sample_fn, ) else: assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)" diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 9dff5fe25..db4b25fb2 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -41,6 +41,7 @@ def basic_clean(text): if is_ftfy_available(): import ftfy + text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() @@ -398,7 +399,7 @@ def __call__( num_channels_latents=num_channel_latents, ) - data_sharding = NamedSharding(self.devices_array, P()) + data_sharding = NamedSharding(self.mesh, P()) if len(prompt) % jax.device_count() == 0: data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py index da447e91d..1f9c3a78e 100644 --- a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py @@ -31,264 +31,263 @@ @flax.struct.dataclass class FlowMatchSchedulerState + """ + Data class to hold the 
mutable state of the FlaxFlowMatchScheduler. + """ + + sigmas: jnp.ndarray + timesteps: jnp.ndarray + linear_timesteps_weights: Optional[jnp.ndarray] + training: bool + num_inference_steps: int # Store for training weight calculation + + @classmethod + def create(cls): + return cls( + sigmas=None, + timesteps=None, + linear_timesteps_weights=None, + training=False, + num_inference_steps=0, + ) + + +@flax.struct.dataclass(frozen=False) +class FlaxFlowMatchSchedulerOutput(FlaxSchedulerOutput): + """ + Output class for the JAX FlowMatchScheduler's step function. + + Attributes: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The computed sample at the previous timestep. + state (`FlowMatchSchedulerState`): + The updated scheduler state. + """ + + state: FlowMatchSchedulerState + + +class FlaxFlowMatchScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + FlaxFlowMatchScheduler is a JAX/Flax conversion of a scheduler used for training video generation models like + WAN 2.1. It operates based on a "flow matching" paradigm. + + This scheduler directly calculates sigmas for a continuous-time diffusion process, which can be beneficial for + certain types of models and training schemes. 
+ """ + + dtype: jnp.dtype + + @property + def has_state(self) -> bool: + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 3.0, + sigma_max: float = 1.0, + sigma_min: float = 0.003 / 1.002, + inverse_timesteps: bool = False, + extra_one_step: bool = False, + reverse_sigmas: bool = False, + dtype: jnp.dtype = jnp.float32, + ): + self.dtype = dtype + + def create_state(self) -> FlowMatchSchedulerState: + """Creates the initial state for the scheduler.""" + + return FlowMatchSchedulerState.create() + + def set_timesteps( + self, + state: FlowMatchSchedulerState, + num_inference_steps: int = 100, + shape: Tuple = None, # Not used but part of the standard API + denoising_strength: float = 1.0, + training: bool = False, + shift: Optional[float] = None, + ) -> FlowMatchSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. + + Args: + state (`FlowMatchSchedulerState`): + The current scheduler state. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + The shape of the samples. + denoising_strength (`float`): + The strength of the denoising process. + training (`bool`): + Whether the scheduler is being used for training. + shift (`Optional[float]`): + An optional shift value to override the one in the config. + + Returns: + `FlowMatchSchedulerState`: The updated scheduler state. """ - Data class to hold the mutable state of the FlaxFlowMatchScheduler. 
+ current_shift = shift if shift is not None else self.config.shift + sigma_start = self.config.sigma_min + (self.config.sigma_max - self.config.sigma_min) * denoising_strength + + if self.config.extra_one_step: + sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps + 1, dtype=self.dtype)[:-1] + else: + sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps, dtype=self.dtype) + + if self.config.inverse_timesteps: + sigmas = jnp.flip(sigmas, dims=[0]) + + sigmas = current_shift * sigmas / (1 + (current_shift - 1) * sigmas) + + if self.config.reverse_sigmas: + sigmas = 1 - sigmas + + timesteps = sigmas * self.config.num_train_timesteps + + linear_timesteps_weights = None + if training: + x = timesteps + y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) + y_shifted = y - jnp.min(y) + bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted)) + linear_timesteps_weights = bsmntw_weighing + + return state.replace( + sigmas=sigmas, + timesteps=timesteps, + linear_timesteps_weights=linear_timesteps_weights, + training=training, + num_inference_steps=num_inference_steps, + ) + + def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: + """Finds the index of the closest timestep in the scheduler's `timesteps` array.""" + timestep = jnp.asarray(timestep, dtype=state.timesteps.dtype) + if timestep.ndim == 0: + return jnp.argmin(jnp.abs(state.timesteps - timestep)) + else: + diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None]) + return jnp.argmin(diffs, axis=1) + + def step( + self, + state: FlowMatchSchedulerState, + model_output: jnp.ndarray, + timestep: jnp.ndarray, + sample: jnp.ndarray, + to_final: bool = False, + return_dict: bool = True, + ) -> Union[FlaxFlowMatchSchedulerOutput, Tuple]: """ + Propagates the sample with the flow matching scheduler. 
- sigmas: jnp.ndarray - timesteps: jnp.ndarray - linear_timesteps_weights: Optional[jnp.ndarray] - training: bool - num_inference_steps: int # Store for training weight calculation + Args: + state (`FlowMatchSchedulerState`): + The current scheduler state. + model_output (`jnp.ndarray`): + The direct output from the learned diffusion model. + timestep (`jnp.ndarray`): + The current timestep in the diffusion chain. + sample (`jnp.ndarray`): + The current sample (e.g. noisy latents). + to_final (`bool`): + Whether this is the final step. + return_dict (`bool`): + Whether to return a `FlaxFlowMatchSchedulerOutput` object. + + Returns: + `FlaxFlowMatchSchedulerOutput` or `tuple`: A tuple (`prev_sample`, `state`) or a + `FlaxFlowMatchSchedulerOutput` object containing the previous sample and the updated state. + """ + timestep_id = self._find_timestep_id(state, timestep) + sigma = state.sigmas[timestep_id] - @classmethod - def create(cls): - return cls( - sigmas=None, - timesteps=None, - linear_timesteps_weights=None, - training=False, - num_inference_steps=0, - ) + def get_next_sigma(): + return state.sigmas[timestep_id + 1] + def get_final_sigma(): + return jnp.array(1.0 if (self.config.inverse_timesteps or self.config.reverse_sigmas) else 0.0, dtype=sigma.dtype) -@flax.struct.dataclass(frozen=False) -class FlaxFlowMatchSchedulerOutput(FlaxSchedulerOutput): + is_final_step = to_final or jnp.all(timestep_id + 1 >= state.timesteps.shape[0]) + sigma_next = jax.lax.cond(is_final_step, get_final_sigma, get_next_sigma) + + if jnp.ndim(timestep) != 0: + broadcast_shape = (-1,) + (1,) * (sample.ndim - 1) + sigma = sigma.reshape(broadcast_shape) + sigma_next = sigma_next.reshape(broadcast_shape) + + prev_sample = sample + model_output * (sigma_next - sigma) + + if not return_dict: + return (prev_sample, state) + + return FlaxFlowMatchSchedulerOutput(prev_sample=prev_sample, state=state) + + def return_to_timestep( + self, state: FlowMatchSchedulerState, timestep: jnp.ndarray, 
sample: jnp.ndarray, sample_stablized: jnp.ndarray + ) -> jnp.ndarray: + """Calculates the model output required to go from a stabilized sample back to the original sample.""" + timestep_id = self._find_timestep_id(state, timestep) + sigma = state.sigmas[timestep_id] + + if jnp.ndim(timestep) != 0: + sigma = sigma.reshape((-1,) + (1,) * (sample.ndim - 1)) + + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise( + self, + state: FlowMatchSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: """ - Output class for the JAX FlowMatchScheduler's step function. + Adds noise to the original samples according to the flow matching schedule. - Attributes: - prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - The computed sample at the previous timestep. + Args: state (`FlowMatchSchedulerState`): - The updated scheduler state. + The current scheduler state. + original_samples (`jnp.ndarray`): + The original clean samples. + noise (`jnp.ndarray`): + The noise to add to the samples. + timesteps (`jnp.ndarray`): + The timesteps that correspond to the noise levels. + + Returns: + `jnp.ndarray`: The noisy samples. """ - state: FlowMatchSchedulerState + if state.sigmas is None or state.timesteps is None: + raise ValueError("Scheduler's `sigmas` and `timesteps` are not set. Please call `set_timesteps` before `add_noise`.") + timestep_ids = self._find_timestep_id(state, timesteps) + sigmas = state.sigmas[timestep_ids] -class FlaxFlowMatchScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - FlaxFlowMatchScheduler is a JAX/Flax conversion of a scheduler used for training video generation models like - WAN 2.1. It operates based on a "flow matching" paradigm. 
+ broadcast_shape = (-1,) + (1,) * (original_samples.ndim - 1) + sigmas = sigmas.reshape(broadcast_shape) - This scheduler directly calculates sigmas for a continuous-time diffusion process, which can be beneficial for - certain types of models and training schemes. + noisy_samples = (1 - sigmas) * original_samples + sigmas * noise + return noisy_samples + + def training_target(self, sample: jnp.ndarray, noise: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: + """ + Calculates the training target. For flow matching, this is typically the velocity, `x_1 - x_0`, + which is equivalent to `noise - sample` under this scheduler's `add_noise` definition. """ + target = noise - sample + return target + + def training_weight(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: + """Calculates the training weight for a given timestep.""" + timestep_ids = self._find_timestep_id(state, timestep) + weights = state.linear_timesteps_weights[timestep_ids] + return weights - dtype: jnp.dtype - - @property - def has_state(self) -> bool: - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - shift: float = 3.0, - sigma_max: float = 1.0, - sigma_min: float = 0.003 / 1.002, - inverse_timesteps: bool = False, - extra_one_step: bool = False, - reverse_sigmas: bool = False, - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self) -> FlowMatchSchedulerState: - """Creates the initial state for the scheduler.""" - - return FlowMatchSchedulerState.create() - - def set_timesteps( - self, - state: FlowMatchSchedulerState, - num_inference_steps: int = 100, - shape: Tuple = None, # Not used but part of the standard API - denoising_strength: float = 1.0, - training: bool = False, - shift: Optional[float] = None, - ) -> FlowMatchSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain. - - Args: - state (`FlowMatchSchedulerState`): - The current scheduler state. 
- num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - shape (`Tuple`): - The shape of the samples. - denoising_strength (`float`): - The strength of the denoising process. - training (`bool`): - Whether the scheduler is being used for training. - shift (`Optional[float]`): - An optional shift value to override the one in the config. - - Returns: - `FlowMatchSchedulerState`: The updated scheduler state. - """ - current_shift = shift if shift is not None else self.config.shift - sigma_start = self.config.sigma_min + (self.config.sigma_max - self.config.sigma_min) * denoising_strength - - if self.config.extra_one_step: - sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps + 1, dtype=self.dtype)[:-1] - else: - sigmas = jnp.linspace(sigma_start, self.config.sigma_min, num_inference_steps, dtype=self.dtype) - - if self.config.inverse_timesteps: - sigmas = jnp.flip(sigmas, dims=[0]) - - sigmas = current_shift * sigmas / (1 + (current_shift - 1) * sigmas) - - if self.config.reverse_sigmas: - sigmas = 1 - sigmas - - timesteps = sigmas * self.config.num_train_timesteps - - linear_timesteps_weights = None - if training: - x = timesteps - y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) - y_shifted = y - jnp.min(y) - bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted)) - linear_timesteps_weights = bsmntw_weighing - - return state.replace( - sigmas=sigmas, - timesteps=timesteps, - linear_timesteps_weights=linear_timesteps_weights, - training=training, - num_inference_steps=num_inference_steps, - ) - - def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: - """Finds the index of the closest timestep in the scheduler's `timesteps` array.""" - timestep = jnp.asarray(timestep, dtype=state.timesteps.dtype) - if timestep.ndim == 0: - return jnp.argmin(jnp.abs(state.timesteps - timestep)) - else: - 
diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None]) - return jnp.argmin(diffs, axis=1) - - def step( - self, - state: FlowMatchSchedulerState, - model_output: jnp.ndarray, - timestep: jnp.ndarray, - sample: jnp.ndarray, - to_final: bool = False, - return_dict: bool = True, - ) -> Union[FlaxFlowMatchSchedulerOutput, Tuple]: - """ - Propagates the sample with the flow matching scheduler. - - Args: - state (`FlowMatchSchedulerState`): - The current scheduler state. - model_output (`jnp.ndarray`): - The direct output from the learned diffusion model. - timestep (`jnp.ndarray`): - The current timestep in the diffusion chain. - sample (`jnp.ndarray`): - The current sample (e.g. noisy latents). - to_final (`bool`): - Whether this is the final step. - return_dict (`bool`): - Whether to return a `FlaxFlowMatchSchedulerOutput` object. - - Returns: - `FlaxFlowMatchSchedulerOutput` or `tuple`: A tuple (`prev_sample`, `state`) or a - `FlaxFlowMatchSchedulerOutput` object containing the previous sample and the updated state. 
- """ - timestep_id = self._find_timestep_id(state, timestep) - sigma = state.sigmas[timestep_id] - - def get_next_sigma(): - return state.sigmas[timestep_id + 1] - - def get_final_sigma(): - return jnp.array(1.0 if (self.config.inverse_timesteps or self.config.reverse_sigmas) else 0.0, dtype=sigma.dtype) - - is_final_step = to_final or jnp.all(timestep_id + 1 >= state.timesteps.shape[0]) - sigma_next = jax.lax.cond(is_final_step, get_final_sigma, get_next_sigma) - - if jnp.ndim(timestep) != 0: - broadcast_shape = (-1,) + (1,) * (sample.ndim - 1) - sigma = sigma.reshape(broadcast_shape) - sigma_next = sigma_next.reshape(broadcast_shape) - - prev_sample = sample + model_output * (sigma_next - sigma) - - if not return_dict: - return (prev_sample, state) - - return FlaxFlowMatchSchedulerOutput(prev_sample=prev_sample, state=state) - - def return_to_timestep( - self, state: FlowMatchSchedulerState, timestep: jnp.ndarray, sample: jnp.ndarray, sample_stablized: jnp.ndarray - ) -> jnp.ndarray: - """Calculates the model output required to go from a stabilized sample back to the original sample.""" - timestep_id = self._find_timestep_id(state, timestep) - sigma = state.sigmas[timestep_id] - - if jnp.ndim(timestep) != 0: - sigma = sigma.reshape((-1,) + (1,) * (sample.ndim - 1)) - - model_output = (sample - sample_stablized) / sigma - return model_output - - def add_noise( - self, - state: FlowMatchSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - """ - Adds noise to the original samples according to the flow matching schedule. - - Args: - state (`FlowMatchSchedulerState`): - The current scheduler state. - original_samples (`jnp.ndarray`): - The original clean samples. - noise (`jnp.ndarray`): - The noise to add to the samples. - timesteps (`jnp.ndarray`): - The timesteps that correspond to the noise levels. - - Returns: - `jnp.ndarray`: The noisy samples. 
- """ - if state.sigmas is None or state.timesteps is None: - raise ValueError( - "Scheduler's `sigmas` and `timesteps` are not set. Please call `set_timesteps` before `add_noise`." - ) - - timestep_ids = self._find_timestep_id(state, timesteps) - sigmas = state.sigmas[timestep_ids] - - broadcast_shape = (-1,) + (1,) * (original_samples.ndim - 1) - sigmas = sigmas.reshape(broadcast_shape) - - noisy_samples = (1 - sigmas) * original_samples + sigmas * noise - return noisy_samples - - def training_target(self, sample: jnp.ndarray, noise: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: - """ - Calculates the training target. For flow matching, this is typically the velocity, `x_1 - x_0`, - which is equivalent to `noise - sample` under this scheduler's `add_noise` definition. - """ - target = noise - sample - return target - - def training_weight(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: - """Calculates the training weight for a given timestep.""" - timestep_ids = self._find_timestep_id(state, timestep) - weights = state.linear_timesteps_weights[timestep_ids] - return weights - - def __len__(self) -> int: - return self.config.num_train_timesteps \ No newline at end of file + def __len__(self) -> int: + return self.config.num_train_timesteps diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index 3dcaa57d0..a68cc6170 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -142,10 +142,10 @@ def load_dataset(self, pipeline, params, train_states): ) feature_description = { - "pixel_values": tf.io.FixedLenFeature([], tf.string), - "input_ids": tf.io.FixedLenFeature([], tf.string), - "prompt_embeds": tf.io.FixedLenFeature([], tf.string), - "text_embeds": tf.io.FixedLenFeature([], tf.string), + "pixel_values": tf.io.FixedLenFeature([], tf.string), + "input_ids": tf.io.FixedLenFeature([], tf.string), + "prompt_embeds": tf.io.FixedLenFeature([], 
tf.string), + "text_embeds": tf.io.FixedLenFeature([], tf.string), } def prepare_sample(features): @@ -154,7 +154,12 @@ def prepare_sample(features): prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) - return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds} + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "prompt_embeds": prompt_embeds, + "text_embeds": text_embeds, + } data_iterator = make_data_iterator( config, @@ -165,7 +170,7 @@ def prepare_sample(features): tokenize_fn=tokenize_fn, image_transforms_fn=image_transforms_fn, feature_description=feature_description, - prepare_sample_fn=prepare_sample + prepare_sample_fn=prepare_sample, ) return data_iterator diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index c9dc4f34d..6af138bc7 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -23,21 +23,15 @@ import tensorflow as tf import jax.numpy as jnp import jax -import jax.tree_util as jtu -from flax import nnx +from flax import nnx from ..schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning -from ..schedulers import FlaxEulerDiscreteScheduler -from .. import max_utils, max_logging, train_utils, maxdiffusion_utils +from .. 
import max_utils, max_logging, train_utils from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) -from maxdiffusion.multihost_dataloading import _form_global_array from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.generate_wan import run as generate_wan -from maxdiffusion.train_utils import ( - _tensorboard_writer_worker, - load_next_batch, - _metrics_queue -) +from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue) + def generate_sample(config, pipeline, filename_prefix): """ @@ -80,39 +74,41 @@ def load_dataset(self, mesh): # TODO - create a dataset config = self.config if config.dataset_type != "tfrecord" and not config.cache_latents_text_encoder_outputs: - raise ValueError("Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True") + raise ValueError( + "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" + ) feature_description = { - "latents" : tf.io.FixedLenFeature([], tf.string), - "encoder_hidden_states" : tf.io.FixedLenFeature([], tf.string), + "latents": tf.io.FixedLenFeature([], tf.string), + "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), } def prepare_sample(features): latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32) - return {"latents" : latents, "encoder_hidden_states" : encoder_hidden_states} - + return {"latents": latents, "encoder_hidden_states": encoder_hidden_states} + data_iterator = make_data_iterator( - config, - jax.process_index(), - jax.process_count(), - mesh, - self.global_batch_size, - feature_description=feature_description, - prepare_sample_fn=prepare_sample + config, + jax.process_index(), + jax.process_count(), + mesh, + 
self.global_batch_size, + feature_description=feature_description, + prepare_sample_fn=prepare_sample, ) return data_iterator def start_training(self): pipeline = self.load_checkpoint() - #del pipeline.vae + # del pipeline.vae # Generate a sample before training to compare against generated sample after training. - generate_sample(self.config, pipeline, filename_prefix='pre-training-') + generate_sample(self.config, pipeline, filename_prefix="pre-training-") mesh = pipeline.mesh data_iterator = self.load_dataset(mesh) - + # Load FlowMatch scheduler scheduler, scheduler_state = self.create_scheduler() pipeline.scheduler = scheduler @@ -124,11 +120,11 @@ def start_training(self): def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator): graphdef, state = nnx.split((pipeline.transformer, optimizer)) - + writer = max_utils.initialize_summary_writer(self.config) writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) writer_thread.start() - + num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) @@ -189,12 +185,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera writer_thread.join() if writer: writer.flush() - + # load new state for trained tranformer graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state) - generate_sample(self.config, pipeline, filename_prefix='post-training-') + generate_sample(self.config, pipeline, filename_prefix="post-training-") def train_step(state, graphdef, scheduler_state, data, rng, scheduler): @@ -205,14 +201,14 @@ def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng): _, new_rng, timestep_rng = jax.random.split(rng, num=3) def loss_fn(model): - latents = data['latents'] - encoder_hidden_states = data['encoder_hidden_states'] + latents = data["latents"] + encoder_hidden_states = data["encoder_hidden_states"] bsz = latents.shape[0] timesteps = jax.random.randint( - timestep_rng, - (bsz,), - 0, - scheduler.config.num_train_timesteps, + timestep_rng, + (bsz,), + 0, + scheduler.config.num_train_timesteps, ) noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) @@ -224,10 +220,10 @@ def loss_fn(model): is_uncond=jnp.array(False, dtype=jnp.bool_), slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) - + training_target = scheduler.training_target(latents, noise, timesteps) training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) - loss = ((training_target - model_pred) ** 2) + loss = (training_target - model_pred) ** 2 loss = loss * training_weight loss = jnp.mean(loss) From e53ee2bfc191084857e332c9b21a71a99756f38a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 26 Jun 2025 19:23:55 +0000 Subject: [PATCH 4/4] fix input pipeline tests --- .../input_pipeline/input_pipeline_interface.py | 2 +- .../tests/input_pipeline_interface_test.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index d0be27144..0c1c68602 100644 --- 
a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -63,7 +63,7 @@ def make_data_iterator( config.dataset_type == "tfrecord" and config.cache_latents_text_encoder_outputs and feature_description is None - or prepare_sample_fn is None + and prepare_sample_fn is None ): raise ValueError( f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True." diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 9afd22674..79e7a0891 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -506,7 +506,23 @@ def test_make_laion_tfrecord_iterator(self): from_pt=config.from_pt, ) - train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) + feature_description = { + "moments": tf.io.FixedLenFeature([], tf.string), + "clip_embeddings": tf.io.FixedLenFeature([], tf.string), + } + + def _parse_tfrecord_fn(example): + return tf.io.parse_single_example(example, feature_description) + + train_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + global_batch_size, + feature_description=feature_description, + prepare_sample_fn=_parse_tfrecord_fn, + ) data = next(train_iterator) device_count = jax.device_count()