From 5ad4ef4a4d234d35f7a1393244f5abb80815cb94 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Jul 2025 16:37:24 +0000 Subject: [PATCH 01/14] use local_devices instead of devices which defaults to first machine's devices. --- src/maxdiffusion/models/wan/wan_utils.py | 34 ++++++++++++------- .../pipelines/wan/wan_pipeline.py | 4 ++- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 2ceb0f7e6..5a27591d6 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key): return renamed_pt_key -def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): - device = jax.devices(device)[0] +def load_fusionx_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): + device = jax.local_devices(backend=device)[0] with jax.default_device(device): if hf_download: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors") @@ -97,7 +99,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) @@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di return flax_state_dict -def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): - device = jax.devices(device)[0] +def 
load_causvid_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): + device = jax.local_devices(backend=device)[0] with jax.default_device(device): if hf_download: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt") @@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) @@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di return flax_state_dict -def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): +def load_wan_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: - return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers) elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: - return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers) else: - return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, 
num_layers) -def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): - device = jax.devices(device)[0] +def load_base_wan_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): + device = jax.local_devices(backend=device)[0] subfolder = "transformer" filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False @@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index abf449291..9ca2e03b9 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -95,7 +95,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # 4. Load pretrained weights and move them to device using the state shardings from (3) above. # This helps with loading sharded weights directly into the accelerators without fist copying them # all to one device and then distributing them, thus using low HBM memory. 
- params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu") + params = load_wan_transformer( + config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"] + ) params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding[path].value From a769016b4565ae4e2e2b8ad6b8e0ef82d06a23f8 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Jul 2025 22:23:41 +0000 Subject: [PATCH 02/14] accept gcs path for data. --- src/maxdiffusion/input_pipeline/_tfds_data_processing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index ce0ae5169..802786a61 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -117,13 +117,16 @@ def make_tfrecord_iterator( check out preparation script maxdiffusion/pedagogical_examples/to_tfrecords.py """ - # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. + + # checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked. 
+ is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location) + if ( config.cache_latents_text_encoder_outputs - and os.path.isdir(config.dataset_save_location) + and is_dataset_dir_valid and "load_tfrecord_cached" in config.get_keys() and config.load_tfrecord_cached ): From 3f2a800479748fa5357a97ad50432c6a21808a06 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 28 Jul 2025 23:11:37 +0000 Subject: [PATCH 03/14] linting --- src/maxdiffusion/input_pipeline/_tfds_data_processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 802786a61..562d5c718 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -120,10 +120,10 @@ def make_tfrecord_iterator( # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. - + # checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked. is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location) - + if ( config.cache_latents_text_encoder_outputs and is_dataset_dir_valid From 28dbe570678e9f6964affd901a773f0a8155253e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 29 Jul 2025 23:01:45 +0000 Subject: [PATCH 04/14] add remat policy. 
Remove sharding for shard_map splash to lower memory footprint --- src/maxdiffusion/configs/base_wan_14b.yml | 8 +++ src/maxdiffusion/models/attention_flax.py | 38 ++++------ .../models/gradient_checkpoint.py | 70 +++++++++++++++++++ .../wan/transformers/transformer_wan.py | 7 +- .../pipelines/wan/wan_pipeline.py | 1 + 5 files changed, 99 insertions(+), 25 deletions(-) create mode 100644 src/maxdiffusion/models/gradient_checkpoint.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 911127896..a34ac6225 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -182,6 +182,14 @@ transform_images_num_proc: 4 reuse_example_batch: False enable_data_shuffling: True +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters +remat_policy: "NONE" + # checkpoint every number of samples, -1 means don't checkpoint. 
checkpoint_every: -1 # enables one replica to read the ckpt then broadcast to the rest diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 3099a5bc0..530cd720f 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -187,31 +187,12 @@ def _tpu_flash_attention( value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards) q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) - flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) + flash_axis_names_splash_kernel: AxisNames = (HEAD, KV_LENGTH) axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) shard_head_size = mesh.shape["tensor"] - @functools.partial( - jax.jit, - static_argnames=["multi_head_mask", "shard_head_size"], - ) - def wrap_splash_kernel(multi_head_mask, shard_head_size=1): - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=shard_head_size, # the sizes of the axis is sharding over heads - q_seq_shards=1, # the sizes of the axis is sharding over seq_len - block_sizes=block_sizes, - ) - return splash_kernel - - mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) - - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) - segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) - @functools.partial( shard_map.shard_map, mesh=mesh, @@ -219,12 +200,21 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): q_axis_names, kv_axis_names, kv_axis_names, - segment_axis_names_splash_kernel, ), out_specs=q_axis_names, check_rep=False, ) - def wrap_flash_attention(query, key, value, splash_kernel): + def 
wrap_flash_attention(query, key, value): + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + # make_splash_mha is wrapped around shardmap and seq and head is already + # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, # the sizes of the axis is sharding over heads + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, + ) attention_output = jax.vmap(splash_kernel)(query, key, value) return attention_output @@ -236,7 +226,7 @@ def wrap_flash_attention(query, key, value, splash_kernel): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" ) - x = wrap_flash_attention(query, key, value, splash_kernel) + x = wrap_flash_attention(query, key, value) x = x[:, :, :query_seq_len, :kv_size] x = _reshape_heads_to_head_dim(x) @@ -632,7 +622,7 @@ def __init__( use_memory_efficient_attention: bool = False, split_head_dim: bool = False, attention_kernel: str = "flash", - flash_min_seq_length: int = 4096, + flash_min_seq_length: int = 0, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, dtype: jnp.dtype = jnp.float32, diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py new file mode 100644 index 000000000..2fe72b8e8 --- /dev/null +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -0,0 +1,70 @@ +from enum import Enum, auto +from typing import Optional + +import jax +from flax import nnx + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + +# This class only works with NNX modules. 
+class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nnx.Module) -> nnx.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nnx.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 
b1ae70b7a..b7cd35462 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -31,6 +31,7 @@ ) from ...normalization_flax import FP32LayerNorm from ...attention_flax import FlaxWanAttention +from ...gradient_checkpoint import GradientCheckpointType BlockSizes = common_types.BlockSizes @@ -356,6 +357,7 @@ def __init__( weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, attention: str = "dot_product", + remat_policy: str = "None" ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -417,6 +419,8 @@ def init_block(rngs): attention=attention, ) + self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) + self.blocks = init_block(rngs) self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) @@ -469,8 +473,9 @@ def scan_fn(carry, block): return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + rematted_block_forward = self.gradient_checkpoint.apply(scan_fn) final_carry = nnx.scan( - scan_fn, + rematted_block_forward, length=self.num_layers, in_axes=(nnx.Carry, 0), out_axes=nnx.Carry, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 9ca2e03b9..7093d2f4a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -78,6 +78,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["attention"] = config.attention wan_config["precision"] = get_precision(config) wan_config["flash_block_sizes"] = get_flash_block_sizes(config) + wan_config["remat_policy"] = config.remat_policy # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. 
From 00bf1cd6d0e77acff021b1e930162b1f45526efc Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 30 Jul 2025 20:53:12 +0000 Subject: [PATCH 05/14] add support for < 1 batch item per device. --- src/maxdiffusion/configs/base_wan_14b.yml | 2 + .../input_pipeline/_tfds_data_processing.py | 2 +- src/maxdiffusion/models/attention_flax.py | 9 ++--- src/maxdiffusion/multihost_dataloading.py | 39 ++++++++++++++----- .../pipelines/wan/wan_pipeline.py | 1 + src/maxdiffusion/pyconfig.py | 2 + src/maxdiffusion/trainers/wan_trainer.py | 3 +- 7 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index a34ac6225..1333ff0cc 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -54,6 +54,7 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +flash_min_seq_length: 4096 flash_block_sizes: {} # Use on v6e @@ -131,6 +132,7 @@ logical_axis_rules: [ ['activation_batch', 'data'], ['mlp','tensor'], ['embed','fsdp'], + ['heads', 'tensor'], ['norm', 'tensor'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 562d5c718..87b68d1a3 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -105,7 +105,7 @@ def _parse_tfrecord_fn(example): ) # This wraps the tf.data.Dataset for use in the multi-host JAX environment. 
- train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) + train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh, config.global_batch_size) return train_iter diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 530cd720f..0106f8dbd 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -187,11 +187,6 @@ def _tpu_flash_attention( value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards) q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) - flash_axis_names_splash_kernel: AxisNames = (HEAD, KV_LENGTH) - axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) - named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) - - shard_head_size = mesh.shape["tensor"] @functools.partial( shard_map.shard_map, @@ -215,6 +210,9 @@ def wrap_flash_attention(query, key, value): q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, ) + # jax.debug.print("query.shape: {x}", x=query.shape) + # jax.debug.print("key.shape: {x}", x=key.shape) + # jax.debug.print("value.shape: {x}", x=value.shape) attention_output = jax.vmap(splash_kernel)(query, key, value) return attention_output @@ -799,6 +797,7 @@ def __call__( query_proj = _unflatten_heads(query_proj, self.heads) key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) + # output of _unflatten_heads Batch, heads, seq_len, head_dim query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) diff --git a/src/maxdiffusion/multihost_dataloading.py b/src/maxdiffusion/multihost_dataloading.py index 4be0ba8d9..73ce04f23 100644 --- a/src/maxdiffusion/multihost_dataloading.py +++ 
b/src/maxdiffusion/multihost_dataloading.py @@ -37,20 +37,23 @@ def _build_global_shape_and_sharding( - local_shape: tuple[int, ...], global_mesh: Mesh + local_shape: tuple[int, ...], global_mesh: Mesh, global_batch_size: int = 0 ) -> tuple[tuple[int, ...], NamedSharding]: - sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) + #Handle sharding for setting a gbs < jax.device_count + if global_batch_size > 0: + sharding = NamedSharding(global_mesh, PartitionSpec(*global_mesh.axis_names)) + else: + sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:] - return global_shape, sharding -def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: +def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array: """Put local sharded array into local devices""" - global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) + global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh, global_batch_size) try: - local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) + local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=split_axis_index) except ValueError as array_split_error: raise ValueError( f"Unable to put to devices shape {array.shape} with " @@ -62,7 +65,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) -def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array: +def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array: """Splits the host loaded data equally over all devices.""" SLEEP_TIME = 10 @@ -83,7 +86,7 @@ def 
get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Ar if not loaded_data_success: local_data = local_dataset.next() - input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data) + input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh, global_batch_size=global_batch_size, split_axis_index=split_axis_index), local_data) return input_gdas @@ -91,9 +94,25 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Ar class MultiHostDataLoadIterator: """fold get_next_batch_sharded into a iterator class""" - def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh): + def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh, global_batch_size: int = 0): self.global_mesh = global_mesh self.dataloader = dataloader + # Handles sharding for when gbs < number of devices + self.global_batch_size = global_batch_size + # Use the correct axis for splitting the data across when using global_batch_size + split_axis_name = max(global_mesh.shape, key=global_mesh.shape.get) + split_axis_index = 0 + if global_batch_size > 0: + max_logging.log(f"global_batch_size was set to {global_batch_size}, splitting data across {split_axis_name}.") + if split_axis_name == "data": + split_axis_index = 0 + elif split_axis_name == "fsdp": + split_axis_index = 1 + elif split_axis_name == "tensor": + split_axis_index = 2 + else: + raise ValueError(f"Could not find {split_axis_name} to split data over.") + self.split_axis_index = split_axis_index if isinstance(self.dataloader, tf.data.Dataset): self.local_iterator = self.dataloader.as_numpy_iterator() elif isinstance(self.dataloader, Iterable): @@ -114,4 +133,4 @@ def __iter__(self): return self def __next__(self): - return get_next_batch_sharded(self.local_iterator, self.global_mesh) + return get_next_batch_sharded(self.local_iterator, self.global_mesh, self.global_batch_size, 
self.split_axis_index) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7093d2f4a..4cc1a7188 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -79,6 +79,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["precision"] = get_precision(config) wan_config["flash_block_sizes"] = get_flash_block_sizes(config) wan_config["remat_policy"] = config.remat_policy + wan_config["flash_min_seq_length"] = config.flash_min_seq_length # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 92dd2a992..104cd7d99 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -181,6 +181,8 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + if "global_batch_size" not in raw_keys.keys(): + raw_keys["global_batch_size"] = 0 def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index d568709f0..8e3a26111 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -69,7 +69,8 @@ def __init__(self, config): if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") - self.global_batch_size = self.config.per_device_batch_size * jax.device_count() + #self.global_batch_size = self.config.per_device_batch_size * jax.device_count() + self.global_batch_size = config.global_batch_size if config.global_batch_size > 0 else config.per_device_batch_size * jax.device_count() def post_training_steps(self, pipeline, params, train_states, 
msg=""): pass From d3b50c84ccbb44c528c39087991b9754e99c5e34 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 31 Jul 2025 00:09:56 +0000 Subject: [PATCH 06/14] using train state instead. --- .../checkpointing/wan_checkpointer.py | 2 +- .../wan/transformers/transformer_wan.py | 10 --- src/maxdiffusion/trainers/wan_trainer.py | 68 +++++++++++++------ 3 files changed, 49 insertions(+), 31 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 5f64d4880..8704b0af8 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -42,7 +42,7 @@ def _create_optimizer(self, model, config, learning_rate): learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) tx = max_utils.create_optimizer(config, learning_rate_scheduler) - return nnx.Optimizer(model, tx), learning_rate_scheduler + return tx, learning_rate_scheduler def load_wan_configs_from_orbax(self, step): max_logging.log("Restoring stable diffusion configs") diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index b7cd35462..d781ebb48 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -374,16 +374,6 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), - ( - None, - None, - None, - None, - "conv_out", - ), - ), ) # 2. 
Condition embeddings diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 8e3a26111..f5e3ca2a5 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -23,6 +23,7 @@ import tensorflow as tf import jax.numpy as jnp import jax +from jax.sharding import PartitionSpec as P from flax import nnx from maxdiffusion.schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning @@ -34,7 +35,16 @@ from maxdiffusion.video_processor import VideoProcessor from maxdiffusion.utils import load_video from skimage.metrics import structural_similarity as ssim +from flax.training import train_state +class TrainState(train_state.TrainState): + graphdef: nnx.GraphDef + rest_of_state: nnx.State + +def _to_array(x): + if not isinstance(x, jax.Array): + x = jnp.asarray(x) + return x def generate_sample(config, pipeline, filename_prefix): """ @@ -85,6 +95,14 @@ def create_scheduler(self): def calculate_tflops(self, pipeline): max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. 
Returning 0...") return 0 + + def get_data_shardings(self, mesh): + data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) + data_sharding = { + "latents" : data_sharding, + "encoder_hidden_states" : data_sharding + } + return data_sharding def load_dataset(self, mesh): # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 @@ -136,9 +154,7 @@ def start_training(self): scheduler, scheduler_state = self.create_scheduler() pipeline.scheduler = scheduler pipeline.scheduler_state = scheduler_state - optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) - # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator) @@ -146,14 +162,28 @@ def start_training(self): print_ssim(pretrained_video_path, posttrained_video_path) def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator): - - graphdef, state = nnx.split((pipeline.transformer, optimizer)) + mesh = pipeline.mesh + graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
+ + with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + state = TrainState.create( + apply_fn=graphdef.apply, + params=params, + tx=optimizer, + graphdef=graphdef, + rest_of_state=rest_of_state + ) + state = jax.tree.map(_to_array, state) + state_spec = nnx.get_partition_spec(state) + state = jax.lax.with_sharding_constraint(state, state_spec) + state_shardings = nnx.get_named_sharding(state, mesh) + data_shardings = self.get_data_shardings(mesh) writer = max_utils.initialize_summary_writer(self.config) writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) writer_thread.start() - num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0]) + num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) max_utils.add_config_to_summary_writer(self.config, writer) @@ -164,9 +194,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera max_logging.log(f" Total train batch size (w. 
parallel & distributed) = {self.global_batch_size}") max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") - state = state.to_pure_dict() p_train_step = jax.jit( functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config), + in_shardings = (state_shardings, data_shardings, None, None), + out_shardings = (state_shardings, None, None, None), donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) @@ -195,7 +226,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( self.config.logical_axis_rules ): - state, scheduler_state, train_metric, rng = p_train_step(state, graphdef, scheduler_state, example_batch, rng) + state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state) train_metric["scalar"]["learning/loss"].block_until_ready() last_step_completion = datetime.datetime.now() @@ -215,19 +246,19 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera writer.flush() # load new state for trained tranformer - graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
- pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state) + pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) return pipeline -def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config): - return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config) +def train_step(state, data, rng, scheduler_state, scheduler, config): + return step_optimizer(state, data, rng, scheduler_state, scheduler, config) -def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config): +def step_optimizer(state, data, rng, scheduler_state, scheduler, config): _, new_rng, timestep_rng = jax.random.split(rng, num=3) - def loss_fn(model): + def loss_fn(params): + model = nnx.merge(state.graphdef, params, state.rest_of_state) latents = data["latents"].astype(config.weights_dtype) encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) bsz = latents.shape[0] @@ -253,11 +284,8 @@ def loss_fn(model): loss = jnp.mean(loss) return loss - - model, optimizer = nnx.merge(graphdef, state) - loss, grads = nnx.value_and_grad(loss_fn)(model) - optimizer.update(grads) - state = nnx.state((model, optimizer)) - state = state.to_pure_dict() + grad_fn = nnx.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + new_state = state.apply_gradients(grads=grads) metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} - return state, scheduler_state, metrics, new_rng + return new_state, scheduler_state, metrics, new_rng From 340d7c463b655ce780229cdc88bd8cb415582aff Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 31 Jul 2025 01:03:45 +0000 Subject: [PATCH 07/14] set data sharding correctly for gbs < 1 --- src/maxdiffusion/trainers/wan_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index f5e3ca2a5..a53e87f3b 100644 --- 
a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -97,7 +97,7 @@ def calculate_tflops(self, pipeline): return 0 def get_data_shardings(self, mesh): - data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding[0])) data_sharding = { "latents" : data_sharding, "encoder_hidden_states" : data_sharding @@ -146,7 +146,7 @@ def start_training(self): # del pipeline.vae # Generate a sample before training to compare against generated sample after training. - pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + #pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") mesh = pipeline.mesh data_iterator = self.load_dataset(mesh) From deb686dcddeaf09aba4a47a7f9a6ce4a4d0b891f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 31 Jul 2025 20:51:03 +0000 Subject: [PATCH 08/14] revert global_batch_size change. --- .../input_pipeline/_tfds_data_processing.py | 2 +- src/maxdiffusion/multihost_dataloading.py | 39 +++++-------------- 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 87b68d1a3..562d5c718 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -105,7 +105,7 @@ def _parse_tfrecord_fn(example): ) # This wraps the tf.data.Dataset for use in the multi-host JAX environment. 
- train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh, config.global_batch_size) + train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter diff --git a/src/maxdiffusion/multihost_dataloading.py b/src/maxdiffusion/multihost_dataloading.py index 73ce04f23..26b734f23 100644 --- a/src/maxdiffusion/multihost_dataloading.py +++ b/src/maxdiffusion/multihost_dataloading.py @@ -37,23 +37,20 @@ def _build_global_shape_and_sharding( - local_shape: tuple[int, ...], global_mesh: Mesh, global_batch_size: int = 0 + local_shape: tuple[int, ...], global_mesh: Mesh ) -> tuple[tuple[int, ...], NamedSharding]: - #Handle sharding for setting a gbs < jax.device_count - if global_batch_size > 0: - sharding = NamedSharding(global_mesh, PartitionSpec(*global_mesh.axis_names)) - else: - sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) + sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:] + return global_shape, sharding -def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array: +def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: """Put local sharded array into local devices""" - global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh, global_batch_size) + global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) try: - local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=split_axis_index) + local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) except ValueError as array_split_error: raise ValueError( f"Unable to put to devices shape {array.shape} with " @@ -65,7 +62,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_ return 
jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) -def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array: +def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array: """Splits the host loaded data equally over all devices.""" SLEEP_TIME = 10 @@ -86,7 +83,7 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_ba if not loaded_data_success: local_data = local_dataset.next() - input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh, global_batch_size=global_batch_size, split_axis_index=split_axis_index), local_data) + input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data) return input_gdas @@ -94,25 +91,9 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_ba class MultiHostDataLoadIterator: """fold get_next_batch_sharded into a iterator class""" - def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh, global_batch_size: int = 0): + def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh): self.global_mesh = global_mesh self.dataloader = dataloader - # Handles sharding for when gbs < number of devices - self.global_batch_size = global_batch_size - # Use the correct axis for splitting the data across when using global_batch_size - split_axis_name = max(global_mesh.shape, key=global_mesh.shape.get) - split_axis_index = 0 - if global_batch_size > 0: - max_logging.log(f"global_batch_size was set to {global_batch_size}, splitting data across {split_axis_name}.") - if split_axis_name == "data": - split_axis_index = 0 - elif split_axis_name == "fsdp": - split_axis_index = 1 - elif split_axis_name == "tensor": - split_axis_index = 2 - else: - raise ValueError(f"Could not find {split_axis_name} to split data over.") - self.split_axis_index = 
split_axis_index if isinstance(self.dataloader, tf.data.Dataset): self.local_iterator = self.dataloader.as_numpy_iterator() elif isinstance(self.dataloader, Iterable): @@ -133,4 +114,4 @@ def __iter__(self): return self def __next__(self): - return get_next_batch_sharded(self.local_iterator, self.global_mesh, self.global_batch_size, self.split_axis_index) + return get_next_batch_sharded(self.local_iterator, self.global_mesh) \ No newline at end of file From 66c85fe254d46a0635481f52593412232ddf2a33 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 31 Jul 2025 22:43:30 +0000 Subject: [PATCH 09/14] add shardings to projection and patch embedding. --- src/maxdiffusion/configs/base_wan_14b.yml | 2 +- .../models/wan/transformers/transformer_wan.py | 17 +++++++++++++++++ src/maxdiffusion/trainers/wan_trainer.py | 4 ++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1333ff0cc..7e506e951 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -136,7 +136,7 @@ logical_axis_rules: [ ['norm', 'tensor'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], - ['conv_in', 'fsdp'], + ['conv_out', 'fsdp'], ] data_sharding: [['data', 'fsdp', 'tensor']] diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index d781ebb48..b16c57efc 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -171,6 +171,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + "mlp", + "embed", + ), + ), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -374,6 +381,16 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + 
kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + ( + None, + None, + None, + None, + "conv_out" + ), + ), ) # 2. Condition embeddings diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index a53e87f3b..ce54b49d0 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -80,7 +80,7 @@ def __init__(self, config): raise ValueError("this script currently doesn't support training text_encoders") #self.global_batch_size = self.config.per_device_batch_size * jax.device_count() - self.global_batch_size = config.global_batch_size if config.global_batch_size > 0 else config.per_device_batch_size * jax.device_count() + self.global_batch_size = config.per_device_batch_size * jax.device_count() def post_training_steps(self, pipeline, params, train_states, msg=""): pass @@ -97,7 +97,7 @@ def calculate_tflops(self, pipeline): return 0 def get_data_shardings(self, mesh): - data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding[0])) + data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) data_sharding = { "latents" : data_sharding, "encoder_hidden_states" : data_sharding From 3b729afedcffac338cc133dc8ae71a27fb7a2686 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 31 Jul 2025 23:30:15 +0000 Subject: [PATCH 10/14] readd generating samples before training. --- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index ce54b49d0..1714a054c 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -146,7 +146,7 @@ def start_training(self): # del pipeline.vae # Generate a sample before training to compare against generated sample after training. 
- #pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") mesh = pipeline.mesh data_iterator = self.load_dataset(mesh) From 42c3920da97bf3ab8439d0741f305e714744138e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 31 Jul 2025 23:51:53 +0000 Subject: [PATCH 11/14] linting. --- .../checkpointing/wan_checkpointer.py | 1 - src/maxdiffusion/models/attention_flax.py | 2 +- .../models/gradient_checkpoint.py | 3 ++- .../wan/transformers/transformer_wan.py | 10 ++------- src/maxdiffusion/multihost_dataloading.py | 2 +- src/maxdiffusion/trainers/wan_trainer.py | 22 ++++++++----------- 6 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 8704b0af8..8f1e2654e 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -15,7 +15,6 @@ """ from abc import ABC -from flax import nnx from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline import WanPipeline from .. 
import max_logging, max_utils diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 0106f8dbd..b7e32523e 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -202,7 +202,7 @@ def _tpu_flash_attention( def wrap_flash_attention(query, key, value): mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - # make_splash_mha is wrapped around shardmap and seq and head is already + # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 2fe72b8e8..ec83c4657 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -6,7 +6,8 @@ SKIP_GRADIENT_CHECKPOINT_KEY = "skip" -# This class only works with NNX modules. + +# This class only works with NNX modules. 
class GradientCheckpointType(Enum): """ Defines the type of the gradient checkpoint we will have diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index b16c57efc..73ecfeb8a 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -364,7 +364,7 @@ def __init__( weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, attention: str = "dot_product", - remat_policy: str = "None" + remat_policy: str = "None", ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -383,13 +383,7 @@ def __init__( precision=precision, kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), - ( - None, - None, - None, - None, - "conv_out" - ), + (None, None, None, None, "conv_out"), ), ) diff --git a/src/maxdiffusion/multihost_dataloading.py b/src/maxdiffusion/multihost_dataloading.py index 26b734f23..4be0ba8d9 100644 --- a/src/maxdiffusion/multihost_dataloading.py +++ b/src/maxdiffusion/multihost_dataloading.py @@ -114,4 +114,4 @@ def __iter__(self): return self def __next__(self): - return get_next_batch_sharded(self.local_iterator, self.global_mesh) \ No newline at end of file + return get_next_batch_sharded(self.local_iterator, self.global_mesh) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 1714a054c..171500267 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -37,15 +37,18 @@ from skimage.metrics import structural_similarity as ssim from flax.training import train_state + class TrainState(train_state.TrainState): graphdef: nnx.GraphDef rest_of_state: nnx.State + def _to_array(x): if not isinstance(x, jax.Array): x = jnp.asarray(x) return x + def generate_sample(config, pipeline, filename_prefix): """ Generates a video to 
validate training did not corrupt the model @@ -79,7 +82,6 @@ def __init__(self, config): if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") - #self.global_batch_size = self.config.per_device_batch_size * jax.device_count() self.global_batch_size = config.per_device_batch_size * jax.device_count() def post_training_steps(self, pipeline, params, train_states, msg=""): @@ -95,13 +97,10 @@ def create_scheduler(self): def calculate_tflops(self, pipeline): max_logging.log("WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0...") return 0 - + def get_data_shardings(self, mesh): data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) - data_sharding = { - "latents" : data_sharding, - "encoder_hidden_states" : data_sharding - } + data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding} return data_sharding def load_dataset(self, mesh): @@ -167,11 +166,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): state = TrainState.create( - apply_fn=graphdef.apply, - params=params, - tx=optimizer, - graphdef=graphdef, - rest_of_state=rest_of_state + apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state ) state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state) @@ -196,8 +191,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera p_train_step = jax.jit( functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config), - in_shardings = (state_shardings, data_shardings, None, None), - out_shardings = (state_shardings, None, None, None), + in_shardings=(state_shardings, data_shardings, None, None), + out_shardings=(state_shardings, None, None, None), donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) @@ -284,6 +279,7 @@ def 
loss_fn(params): loss = jnp.mean(loss) return loss + grad_fn = nnx.value_and_grad(loss_fn) loss, grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) From 6d0a2c39f77066e48341c08fb4cf8503591fa2d8 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 1 Aug 2025 15:28:22 +0000 Subject: [PATCH 12/14] add license headers, remove commented out code. --- src/maxdiffusion/models/attention_flax.py | 3 --- src/maxdiffusion/models/gradient_checkpoint.py | 16 ++++++++++++++++ src/maxdiffusion/models/wan/wan_utils.py | 16 ++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index b7e32523e..868beae71 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -210,9 +210,6 @@ def wrap_flash_attention(query, key, value): q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, ) - # jax.debug.print("query.shape: {x}", x=query.shape) - # jax.debug.print("key.shape: {x}", x=key.shape) - # jax.debug.print("value.shape: {x}", x=value.shape) attention_output = jax.vmap(splash_kernel)(query, key, value) return attention_output diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index ec83c4657..3f4476bbe 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -1,3 +1,19 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +""" + from enum import Enum, auto from typing import Optional diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 5a27591d6..26eff1137 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,3 +1,19 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + import os import json import torch From 54e672e1d363b80443386c994751d24d79b9fab0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 1 Aug 2025 21:10:15 +0000 Subject: [PATCH 13/14] shard proj and proj_out across fsdp. --- .../wan/transformers/transformer_wan.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 73ecfeb8a..617719125 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -18,6 +18,7 @@ import math import jax import jax.numpy as jnp +from jax.sharding import PartitionSpec from flax import nnx import numpy as np from .... 
import common_types @@ -174,10 +175,14 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( + None, "mlp", "embed", ), ), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, (None, "embed") + ) ) def __call__(self, x: jax.Array) -> jax.Array: @@ -225,8 +230,9 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - "mlp", + None, "embed", + "mlp", ), ), ) @@ -314,7 +320,18 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.scale_shift_table + temb), 6, axis=1 ) - + + # shift_msa = jax.lax.with_sharding_constraint(shift_msa, PartitionSpec("data", None, "tensor")) + # scale_msa = jax.lax.with_sharding_constraint(scale_msa, PartitionSpec("data", None, "tensor")) + # gate_msa = jax.lax.with_sharding_constraint(gate_msa, PartitionSpec("data", None, "tensor")) + # c_shift_msa = jax.lax.with_sharding_constraint(c_shift_msa, PartitionSpec("data", None, "tensor")) + # c_scale_msa = jax.lax.with_sharding_constraint(c_scale_msa, PartitionSpec("data", None, "tensor")) + # c_gate_msa = jax.lax.with_sharding_constraint(c_gate_msa, PartitionSpec("data", None, "tensor")) + + # hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + # encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", None, "tensor")) + # temb = jax.lax.with_sharding_constraint(temb, PartitionSpec("data", None, "tensor")) + # 1. 
Self-attention norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) attn_output = self.attn1( @@ -457,6 +474,7 @@ def __call__( hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) + #rotary_emb = jax.lax.with_sharding_constraint(rotary_emb, PartitionSpec(None, None, "fsdp", None)) hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) @@ -464,10 +482,9 @@ def __call__( timestep, encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) - + #hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") - def scan_fn(carry, block): hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) From 7882ee3091b335d309bc1c5de81e61e58d52cceb Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 5 Aug 2025 05:12:27 +0000 Subject: [PATCH 14/14] support fractional batch sizes, shard projections. 
--- src/maxdiffusion/configs/base_wan_14b.yml | 5 ++-- src/maxdiffusion/generate_wan.py | 12 ++------ src/maxdiffusion/models/attention_flax.py | 6 +++- .../models/gradient_checkpoint.py | 6 ++++ .../wan/transformers/transformer_wan.py | 28 ++++++------------- src/maxdiffusion/models/wan/wan_utils.py | 1 + .../pipelines/wan/wan_pipeline.py | 3 +- src/maxdiffusion/pyconfig.py | 17 +++++++++-- src/maxdiffusion/trainers/wan_trainer.py | 9 +++--- 9 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 7e506e951..b552b0621 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -127,9 +127,10 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], + ['activation_batch', 'data'], ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], - ['activation_batch', 'data'], ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], @@ -206,7 +207,7 @@ max_train_steps: 1500 num_train_epochs: 1 seed: 0 output_dir: 'sdxl-model-finetuned' -per_device_batch_size: 1 +per_device_batch_size: 1.0 # If global_batch_size % jax.device_count is not 0, use FSDP sharding. global_batch_size: 0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d3c8d47cf..a9bcf366c 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -29,15 +29,9 @@ def run(config, pipeline=None, filename_prefix=""): pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() - # If global_batch_size % jax.device_count is not 0, use FSDP sharding. 
- global_batch_size = config.global_batch_size - if global_batch_size != 0: - batch_multiplier = global_batch_size - else: - batch_multiplier = jax.device_count() * config.per_device_batch_size - - prompt = [config.prompt] * batch_multiplier - negative_prompt = [config.negative_prompt] * batch_multiplier + # Using global_batch_size_to_train_on so not to create more config variables + prompt = [config.prompt] * config.global_batch_size_to_train_on + negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 868beae71..fe86e08c4 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -18,6 +18,7 @@ import flax.linen as nn from flax import nnx import jax +from jax.ad_checkpoint import checkpoint_name from jax.sharding import PartitionSpec import jax.numpy as jnp from jax.experimental import shard_map @@ -797,10 +798,13 @@ def __call__( # output of _unflatten_heads Batch, heads, seq_len, head_dim query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + query_proj = checkpoint_name(query_proj, "query_proj") + key_proj = checkpoint_name(key_proj, "key_proj") + value_proj = checkpoint_name(value_proj, "value_proj") attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) attn_output = attn_output.astype(dtype=dtype) - + attn_output = checkpoint_name(attn_output, "attn_output") hidden_states = self.proj_attn(attn_output) return hidden_states diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 3f4476bbe..28f637c23 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -18,6 +18,7 @@ from typing 
import Optional import jax +from jax import checkpoint_policies as cp from flax import nnx SKIP_GRADIENT_CHECKPOINT_KEY = "skip" @@ -38,6 +39,7 @@ class GradientCheckpointType(Enum): NONE = auto() FULL = auto() MATMUL_WITHOUT_BATCH = auto() + ATTN = auto() @classmethod def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": @@ -63,6 +65,10 @@ def to_jax_policy(self): return SKIP_GRADIENT_CHECKPOINT_KEY case GradientCheckpointType.FULL: return None + case GradientCheckpointType.ATTN: + return cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host" + ) case GradientCheckpointType.MATMUL_WITHOUT_BATCH: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 617719125..6588929b1 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -180,9 +180,7 @@ def __init__( "embed", ), ), - bias_init=nnx.with_partitioning( - nnx.initializers.zeros, (None, "embed") - ) + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -314,24 +312,15 @@ def __init__( self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) key = rngs.params() - self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5) + self.adaln_scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5) def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (self.scale_shift_table + temb), 6, axis=1 + (self.adaln_scale_shift_table + temb), 6, axis=1 ) - - # shift_msa = 
jax.lax.with_sharding_constraint(shift_msa, PartitionSpec("data", None, "tensor")) - # scale_msa = jax.lax.with_sharding_constraint(scale_msa, PartitionSpec("data", None, "tensor")) - # gate_msa = jax.lax.with_sharding_constraint(gate_msa, PartitionSpec("data", None, "tensor")) - # c_shift_msa = jax.lax.with_sharding_constraint(c_shift_msa, PartitionSpec("data", None, "tensor")) - # c_scale_msa = jax.lax.with_sharding_constraint(c_scale_msa, PartitionSpec("data", None, "tensor")) - # c_gate_msa = jax.lax.with_sharding_constraint(c_gate_msa, PartitionSpec("data", None, "tensor")) - - # hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) - # encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", None, "tensor")) - # temb = jax.lax.with_sharding_constraint(temb, PartitionSpec("data", None, "tensor")) - + hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) + # 1. 
Self-attention norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) attn_output = self.attn1( @@ -474,7 +463,7 @@ def __call__( hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) - #rotary_emb = jax.lax.with_sharding_constraint(rotary_emb, PartitionSpec(None, None, "fsdp", None)) + hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) @@ -482,9 +471,10 @@ def __call__( timestep, encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) - #hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") + def scan_fn(carry, block): hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 26eff1137..628207a9a 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -241,6 +241,7 @@ def load_base_wan_transformer( for pt_key, tensor in tensors.items(): renamed_pt_key = rename_key(pt_key) renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 4cc1a7188..8d2f2cd3b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ 
b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -416,7 +416,8 @@ def __call__( ) data_sharding = NamedSharding(self.mesh, P()) - if len(prompt) % jax.device_count() == 0: + # Reuse global_batch_size_to_train_on (no extra config variable); shard on data only when it divides evenly across devices. + if self.config.global_batch_size_to_train_on % jax.device_count() == 0: data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) latents = jax.device_put(latents, data_sharding) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 104cd7d99..8e758d661 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -137,6 +137,18 @@ def wan_init(raw_keys): else: raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + @staticmethod + def calculate_global_batch_sizes(per_device_batch_size): + num_devices = len(jax.devices()) + if per_device_batch_size < 1: + # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 + global_batch_size_to_load = num_devices + else: + global_batch_size_to_load = int(num_devices * per_device_batch_size) + + global_batch_size_to_train_on = int(num_devices * per_device_batch_size) + return global_batch_size_to_load, global_batch_size_to_train_on + @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" @@ -181,8 +193,9 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - if "global_batch_size" not in raw_keys.keys(): - raw_keys["global_batch_size"] = 0 + raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = ( + _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) + ) def get_num_slices(raw_keys): 
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 171500267..3b0b520bf 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -82,8 +82,6 @@ def __init__(self, config): if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") - self.global_batch_size = config.per_device_batch_size * jax.device_count() - def post_training_steps(self, pipeline, params, train_states, msg=""): pass @@ -133,7 +131,7 @@ def prepare_sample(features): jax.process_index(), jax.process_count(), mesh, - self.global_batch_size, + config.global_batch_size_to_load, feature_description=feature_description, prepare_sample_fn=prepare_sample, ) @@ -186,7 +184,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera if jax.process_index() == 0: max_logging.log("***** Running training *****") max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") - max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}") + max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.config.global_batch_size_to_train_on}") max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") p_train_step = jax.jit( @@ -252,6 +250,9 @@ def train_step(state, data, rng, scheduler_state, scheduler, config): def step_optimizer(state, data, rng, scheduler_state, scheduler, config): _, new_rng, timestep_rng = jax.random.split(rng, num=3) + for k, v in data.items(): + data[k] = v[: config.global_batch_size_to_train_on, :] + def loss_fn(params): model = nnx.merge(state.graphdef, params, state.rest_of_state) latents = data["latents"].astype(config.weights_dtype)