diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 486c8ba3c..56fa47ca9 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -56,8 +56,9 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
+dropout: 0.1
 flash_block_sizes: {}
 # Use on v6e
@@ -193,8 +194,14 @@ enable_data_shuffling: True
 # FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
 # MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
 # except for ones that involve batch dimension - that means that all attention and projection
-# layers will have gradient checkpoint, but not the backward with respect to the parameters
+# layers will have gradient checkpoint, but not the backward with respect to the parameters.
+# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH, but offloads activations instead of recomputing them.
+# CUSTOM - save or offload only the tensors whose checkpoint names are listed below.
 remat_policy: "NONE"
+# For the CUSTOM policy, fill in the lists below. Currently annotated names are: attn_output, query_proj, key_proj, value_proj,
+# xq_out, xk_out, ffn_activation
+names_which_can_be_saved: []
+names_which_can_be_offloaded: []
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index 5d1785070..8463ebaa1 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -47,21 +47,24 @@
 _re_configuration_file = re.compile(r"config\.(.*)\.json")
 
+
 class CustomEncoder(json.JSONEncoder):
-    """
-    Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
-    """
-    def default(self, o):
-        # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
-        if isinstance(o, type(jnp.dtype('bfloat16'))):
-            return str(o)
-        # Add fallbacks for other numpy types if needed
-        if isinstance(o, np.integer):
-            return int(o)
-        if isinstance(o, np.floating):
-            return float(o)
-        # Let the base class default method raise the TypeError for other types
-        return super().default(o)
+  """
+  Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
+ """ + + def default(self, o): + # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16" + if isinstance(o, type(jnp.dtype("bfloat16"))): + return str(o) + # Add fallbacks for other numpy types if needed + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + # Let the base class default method raise the TypeError for other types + return super().default(o) + class FrozenDict(OrderedDict): @@ -596,14 +599,14 @@ def to_json_saveable(value): config_dict.pop("quant", None) keys_to_remove = [] for key, value in config_dict.items(): - # Check the type of the value by its class name to avoid import issues - if type(value).__name__ == 'Rngs': - keys_to_remove.append(key) + # Check the type of the value by its class name to avoid import issues + if type(value).__name__ == "Rngs": + keys_to_remove.append(key) if keys_to_remove: - max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}") - for key in keys_to_remove: - config_dict.pop(key) + max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}") + for key in keys_to_remove: + config_dict.pop(key) try: diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 1dc1789a1..3530d5eb0 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -22,43 +22,47 @@ from maxdiffusion.utils import export_to_video from google.cloud import storage + def upload_video_to_gcs(output_dir: str, video_path: str): - """ - Uploads a local video file to a specified Google Cloud Storage bucket. - """ - try: - path_without_scheme = output_dir.removeprefix("gs://") - parts = path_without_scheme.split('/', 1) - bucket_name = parts[0] - folder_name = parts[1] if len(parts) > 1 else '' + """ + Uploads a local video file to a specified Google Cloud Storage bucket. 
+ """ + try: + path_without_scheme = output_dir.removeprefix("gs://") + parts = path_without_scheme.split("/", 1) + bucket_name = parts[0] + folder_name = parts[1] if len(parts) > 1 else "" - storage_client = storage.Client() - bucket = storage_client.bucket(bucket_name) + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) - source_file_path = f"./{video_path}" - destination_blob_name = os.path.join(folder_name, "videos", video_path) + source_file_path = f"./{video_path}" + destination_blob_name = os.path.join(folder_name, "videos", video_path) - blob = bucket.blob(destination_blob_name) + blob = bucket.blob(destination_blob_name) - max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...") - blob.upload_from_filename(source_file_path) - max_logging.log(f"Upload complete {source_file_path}.") + max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...") + blob.upload_from_filename(source_file_path) + max_logging.log(f"Upload complete {source_file_path}.") + + except Exception as e: + max_logging.log(f"An error occurred: {e}") - except Exception as e: - max_logging.log(f"An error occurred: {e}") def delete_file(file_path: str): if os.path.exists(file_path): - try: - os.remove(file_path) - max_logging.log(f"Successfully deleted file: {file_path}") - except OSError as e: - max_logging.log(f"Error deleting file '{file_path}': {e}") + try: + os.remove(file_path) + max_logging.log(f"Successfully deleted file: {file_path}") + except OSError as e: + max_logging.log(f"Error deleting file '{file_path}': {e}") else: - max_logging.log(f"The file '{file_path}' does not exist.") + max_logging.log(f"The file '{file_path}' does not exist.") + jax.config.update("jax_use_shardy_partitioner", True) + def inference_generate_video(config, pipeline, filename_prefix=""): s0 = time.perf_counter() prompt = [config.prompt] * config.global_batch_size_to_train_on @@ -88,6 +92,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""): delete_file(f"./{video_path}") return + def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 885d59ef6..bb0428aa9 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -78,9 +78,18 @@ def make_tf_iterator( train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter + # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py def _make_tfrecord_iterator( - config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description_fn, + prepare_sample_fn, + dataset_path, + is_training: bool, ): # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. @@ -93,10 +102,10 @@ def _make_tfrecord_iterator( # Determine whether to use the "cached" dataset, which requires externally # provided parsing functions, or the default one with its internal parsing logic. 
   make_cached_tfrecord_iterator = (
-      config.cache_latents_text_encoder_outputs
-      and is_dataset_dir_valid
-      and "load_tfrecord_cached" in config.get_keys()
-      and config.load_tfrecord_cached
+    config.cache_latents_text_encoder_outputs
+    and is_dataset_dir_valid
+    and "load_tfrecord_cached" in config.get_keys()
+    and config.load_tfrecord_cached
   )
 
   feature_description = {
@@ -121,42 +130,47 @@ def prepare_sample(features):
 
   if not is_training:
     num_eval_samples = 0
     for _ in ds:
-        num_eval_samples += 1
+      num_eval_samples += 1
     remainder = num_eval_samples % global_batch_size
     if remainder != 0:
-        num_to_pad = global_batch_size - remainder
-        # Create a dataset of padding samples from the beginning
-        padding_ds = ds.take(num_to_pad)
-        # Add the padding samples to the end
-        ds = ds.concatenate(padding_ds)
-        max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
+      num_to_pad = global_batch_size - remainder
+      # Create a dataset of padding samples from the beginning
+      padding_ds = ds.take(num_to_pad)
+      # Add the padding samples to the end
+      ds = ds.concatenate(padding_ds)
+      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
 
   used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
   ds = (
-      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
+    ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+    .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+    .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
   )
 
   if is_training:
     ds = (
-        ds.shuffle(global_batch_size * 10)
-        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-        .repeat(-1)
-        .prefetch(AUTOTUNE)
+      ds.shuffle(global_batch_size * 10)
+      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
+      .repeat(-1)
+      .prefetch(AUTOTUNE)
     )
   # For Evaluation
   else:
-    ds = (
-        ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
-        .prefetch(AUTOTUNE)
-    )
+    ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE)
 
   iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
   return iter
 
+
 def make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+    feature_description,
+    prepare_sample_fn,
+    is_training,
 ):
   """Iterator for TFRecord format. For Laion dataset, check out preparation script
@@ -165,4 +179,14 @@ def make_tfrecord_iterator(
   # Currently only support evaluation on tfrecord. To avoid influencing previous reference, judge whether is training dataset.
   # TODO: refactor to support evaluation on all dataset format.
   dataset_path = config.train_data_dir if is_training else config.eval_data_dir
-  return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
+  return _make_tfrecord_iterator(
+      config,
+      dataloading_host_index,
+      dataloading_host_count,
+      mesh,
+      global_batch_size,
+      feature_description,
+      prepare_sample_fn,
+      dataset_path,
+      is_training,
+  )
diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
index e7014bbc3..16477c35d 100644
--- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
+++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
@@ -107,7 +107,7 @@ def make_data_iterator(
         global_batch_size,
         feature_description,
         prepare_sample_fn,
-        is_training
+        is_training,
     )
   else:
     assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 25788fb69..3cbb0ccea 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -734,7 +734,7 @@ def __init__(
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
     # Dims are [num_blocks, embed, heads]
-    kernel_axes = (None, "embed", "heads")
+    kernel_axes = ("embed", "heads")
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -747,10 +747,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                None,
-                "embed",
-            ),
+            ("embed",),
         ),
     )
 
@@ -764,10 +761,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                None,
-                "embed",
-            ),
+            ("embed",),
         ),
     )
 
@@ -781,10 +775,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            (
-                None,
-                "embed",
-            ),
+            ("embed",),
         ),
     )
 
@@ -792,12 +783,18 @@ def __init__(
         rngs=rngs,
         in_features=self.inner_dim,
         out_features=self.inner_dim,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
+        bias_init=nnx.with_partitioning(
+            nnx.initializers.zeros,
+            ("heads",),
+        ),
     )
 
+    self.drop_out = nnx.Dropout(dropout)
+
     self.norm_q = None
     self.norm_k = None
     if qk_norm is not None:
@@ -808,10 +805,7 @@ def __init__(
           dtype=dtype,
           scale_init=nnx.with_partitioning(
               nnx.initializers.ones,
-              (
-                  None,
-                  "norm",
-              ),
+              ("norm",),
           ),
           param_dtype=weights_dtype,
       )
@@ -822,10 +816,7 @@ def __init__(
           dtype=dtype,
           scale_init=nnx.with_partitioning(
               nnx.initializers.ones,
-              (
-                  None,
-                  "norm",
-              ),
+              ("norm",),
           ),
           param_dtype=weights_dtype,
       )
@@ -847,7 +838,12 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
     return xq_out, xk_out
 
   def __call__(
-      self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None
+      self,
+      hidden_states: jax.Array,
+      encoder_hidden_states: jax.Array = None,
+      rotary_emb: Optional[jax.Array] = None,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
   ) -> jax.Array:
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
@@ -877,6 +873,7 @@ def __call__(
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
     hidden_states = self.proj_attn(attn_output)
+    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
 
     return hidden_states
 
diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py
index 28f637c23..a111ef717 100644
--- a/src/maxdiffusion/models/gradient_checkpoint.py
+++ b/src/maxdiffusion/models/gradient_checkpoint.py
@@ -39,7 +39,8 @@ class GradientCheckpointType(Enum):
   NONE = auto()
   FULL = auto()
   MATMUL_WITHOUT_BATCH = auto()
-  ATTN = auto()
+  OFFLOAD_MATMUL_WITHOUT_BATCH = auto()
+  CUSTOM = auto()
 
   @classmethod
   def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
@@ -56,7 +57,7 @@ def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
       s = "none"
     return GradientCheckpointType[s.upper()]
 
-  def to_jax_policy(self):
+  def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = []):
     """
     Converts the gradient checkpoint type to a jax policy
     """
@@ -65,14 +66,26 @@ def to_jax_policy(self):
         return SKIP_GRADIENT_CHECKPOINT_KEY
       case GradientCheckpointType.FULL:
         return None
-      case GradientCheckpointType.ATTN:
-        return cp.save_and_offload_only_these_names(
-            names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host"
+      case GradientCheckpointType.OFFLOAD_MATMUL_WITHOUT_BATCH:
+        return cp.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host")
+      case GradientCheckpointType.CUSTOM:
+        policy = cp.save_and_offload_only_these_names(
+            names_which_can_be_saved=names_which_can_be_saved,
+            names_which_can_be_offloaded=names_which_can_be_offloaded,
+            offload_src="device",
+            offload_dst="pinned_host",
         )
+        return policy
       case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
         return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
 
-  def apply(self, module: nnx.Module) -> nnx.Module:
+  def apply(
+      self,
+      module: nnx.Module,
+      names_which_can_be_saved: list = [],
+      names_which_can_be_offloaded: list = [],
+      static_argnums=(),
+  ) -> nnx.Module:
     """
     Applies a gradient checkpoint policy to a module
     if no policy is needed, it will return the module as is
@@ -83,11 +96,7 @@ def apply(self, module: nnx.Module) -> nnx.Module:
     Returns:
       nn.Module: the module with the policy applied
     """
-    policy = self.to_jax_policy()
+    policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded)
     if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
       return module
-    return nnx.remat(  # pylint: disable=invalid-name
-        module,
-        prevent_cse=False,
-        policy=policy,
-    )
+    return nnx.remat(module, prevent_cse=False, policy=policy, static_argnums=static_argnums)  # pylint: disable=invalid-name
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 6588929b1..718b5015e 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -19,6 +19,7 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec
+from jax.ad_checkpoint import checkpoint_name
 from flax import nnx
 import numpy as np
 from .... import common_types
@@ -42,7 +43,7 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
   t_dim = attention_head_dim - h_dim - w_dim
   freqs = []
   for dim in [t_dim, h_dim, w_dim]:
-    freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False)
+    freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=False)
     freqs.append(freq)
   freqs = jnp.concatenate(freqs, axis=1)
   t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)
@@ -175,12 +176,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                None,
                 "mlp",
                 "embed",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )
 
   def __call__(self, x: jax.Array) -> jax.Array:
@@ -217,6 +217,7 @@ def __init__(
     else:
       raise NotImplementedError(f"{activation_fn} is not implemented.")
 
+    self.drop_out = nnx.Dropout(dropout)
     self.proj_out = nnx.Linear(
         rngs=rngs,
         in_features=inner_dim,
@@ -228,16 +229,17 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                None,
                 "embed",
                 "mlp",
             ),
         ),
     )
 
-  def __call__(self, hidden_states: jax.Array) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)
-    return self.proj_out(hidden_states)
+  def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
+    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
 class WanTransformerBlock(nnx.Module):
@@ -260,6 +262,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
       attention: str = "dot_product",
+      dropout: float = 0.0,
   ):
 
     # 1. Self-attention
@@ -278,6 +281,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         attention_kernel=attention,
+        dropout=dropout,
     )
 
     # 1. Cross-attention
@@ -295,6 +299,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         attention_kernel=attention,
+        dropout=dropout,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -308,13 +313,24 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
+        dropout=dropout,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
 
     key = rngs.params()
-    self.adaln_scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)
+    self.adaln_scale_shift_table = nnx.Param(
+        jax.random.normal(key, (1, 6, dim)) / dim**0.5,
+    )
 
-  def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array):
+  def __call__(
+      self,
+      hidden_states: jax.Array,
+      encoder_hidden_states: jax.Array,
+      temb: jax.Array,
+      rotary_emb: jax.Array,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
+  ):
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
         (self.adaln_scale_shift_table + temb), 6, axis=1
     )
@@ -324,18 +340,24 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
     # 1. Self-attention
     norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
     attn_output = self.attn1(
-        hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=norm_hidden_states,
+        rotary_emb=rotary_emb,
+        deterministic=deterministic,
+        rngs=rngs,
     )
     hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
 
     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states)
-    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+    attn_output = self.attn2(
+        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+    )
     hidden_states = hidden_states + attn_output
 
     # 3. Feed-forward
     norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
-    ff_output = self.ffn(norm_hidden_states)
+    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
     hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype)
 
     return hidden_states
 
@@ -356,6 +378,7 @@ def __init__(
       freq_dim: int = 256,
       ffn_dim: int = 13824,
       num_layers: int = 40,
+      dropout: float = 0.0,
       cross_attn_norm: bool = True,
       qk_norm: Optional[str] = "rms_norm_across_heads",
       eps: float = 1e-6,
@@ -371,6 +394,8 @@ def __init__(
       precision: jax.lax.Precision = None,
      attention: str = "dot_product",
      remat_policy: str = "None",
+      names_which_can_be_saved: list = [],
+      names_which_can_be_offloaded: list = [],
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
@@ -407,7 +432,7 @@ def __init__(
 
     # 3. Transformer blocks
     @nnx.split_rngs(splits=num_layers)
-    @nnx.vmap(in_axes=0, out_axes=0)
+    @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
     def init_block(rngs):
       return WanTransformerBlock(
           rngs=rngs,
@@ -424,9 +449,12 @@ def init_block(rngs):
           weights_dtype=weights_dtype,
           precision=precision,
           attention=attention,
+          dropout=dropout,
       )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
+    self.names_which_can_be_offloaded = names_which_can_be_offloaded
+    self.names_which_can_be_saved = names_which_can_be_saved
 
     self.blocks = init_block(rngs)
 
@@ -454,6 +482,8 @@ def __call__(
       encoder_hidden_states_image: Optional[jax.Array] = None,
       return_dict: bool = True,
      attention_kwargs: Optional[Dict[str, Any]] = None,
+      deterministic: bool = True,
+      rngs: nnx.Rngs = None,
   ) -> Union[jax.Array, Dict[str, jax.Array]]:
     batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
@@ -476,20 +506,23 @@ def __call__(
       raise NotImplementedError("img2vid is not yet implemented.")
 
     def scan_fn(carry, block):
-      hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry
-      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-      return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+      hidden_states_carry, rngs_carry = carry
+      hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry)
+      new_carry = (hidden_states, rngs_carry)
+      return new_carry, None
 
-    initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-    rematted_block_forward = self.gradient_checkpoint.apply(scan_fn)
-    final_carry = nnx.scan(
+    rematted_block_forward = self.gradient_checkpoint.apply(
+        scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded
+    )
+    initial_carry = (hidden_states, rngs)
+    final_carry, _ = nnx.scan(
         rematted_block_forward,
         length=self.num_layers,
        in_axes=(nnx.Carry, 0),
-        out_axes=nnx.Carry,
+        out_axes=(nnx.Carry, 0),
     )(initial_carry, self.blocks)
-    hidden_states = final_carry[0]
+    hidden_states, _ = final_carry
 
     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 5dc98d087..f2c3701e3 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -86,7 +86,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_config["precision"] = get_precision(config)
     wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
     wan_config["remat_policy"] = config.remat_policy
+    wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved
+    wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
     wan_config["flash_min_seq_length"] = config.flash_min_seq_length
+    wan_config["dropout"] = config.dropout
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
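The CUSTOM remat policy wired through the config above maps the two new lists (names_which_can_be_saved / names_which_can_be_offloaded) onto tensors tagged with checkpoint_name inside the transformer blocks. Below is a minimal, self-contained sketch of that mechanism; it is not part of this patch, the toy ffn function and its shapes are purely illustrative, and it assumes a JAX version that provides jax.checkpoint_policies.save_and_offload_only_these_names (actually offloading to pinned_host additionally requires an accelerator runtime that exposes that memory space, e.g. TPU):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def ffn(x, w1, w2):
  # Tag the activation with the same name the transformer annotates above.
  h = checkpoint_name(jax.nn.gelu(x @ w1), "ffn_activation")
  return h @ w2


# Keep "ffn_activation" in HBM for the backward pass; everything else is recomputed.
# Moving the name to names_which_can_be_offloaded instead parks it in host memory.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["ffn_activation"],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)


def loss(x, w1, w2):
  y = jax.checkpoint(ffn, policy=policy, prevent_cse=False)(x, w1, w2)
  return jnp.sum(y * y)


x = jnp.ones((8, 64))
w1 = jnp.ones((64, 256))
w2 = jnp.ones((256, 64))
print(jax.tree_util.tree_map(jnp.shape, jax.grad(loss, argnums=(1, 2))(x, w1, w2)))

Names that appear in neither list are rematerialized in the backward pass, which is the behaviour the transformer falls back to when both config lists are left empty.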
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 33fc62f83..3bb5bd13c 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -120,6 +120,8 @@ def _load_kwargs(self, argv: list[str]):
 
   @staticmethod
   def wan_init(raw_keys):
+    if not any("layers_per_stage" in inner_tuple for inner_tuple in raw_keys["logical_axis_rules"]):
+      raw_keys["logical_axis_rules"] += (("layers_per_stage", None),)
     if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
       transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
       if transformer_pretrained_model_name_or_path == "":
diff --git a/src/maxdiffusion/tests/gradient_checkpoint_test.py b/src/maxdiffusion/tests/gradient_checkpoint_test.py
new file mode 100644
index 000000000..ca237d523
--- /dev/null
+++ b/src/maxdiffusion/tests/gradient_checkpoint_test.py
@@ -0,0 +1,60 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+from absl.testing import absltest
+
+import jax
+
+from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType
+
+
+class GradientCheckpointTest(unittest.TestCase):
+  """Unit test suite for GradientCheckpointType policies."""
+
+  def test_none_policy(self):
+    policy = GradientCheckpointType.from_str("NONE")
+    self.assertEqual(policy.to_jax_policy(), "skip")
+
+  def test_full_policy(self):
+    policy = GradientCheckpointType.from_str("FULL")
+    self.assertIsNone(policy.to_jax_policy())
+
+  def test_matmul_without_batch_policy(self):
+    policy = GradientCheckpointType.from_str("MATMUL_WITHOUT_BATCH")
+    jax_policy_fn = policy.to_jax_policy()
+    self.assertIs(jax_policy_fn, jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
+
+  def test_offload_matmul_without_batch_policy(self):
+    """
+    Tests the offload variant by checking that the returned policy is callable.
+    """
+    policy = GradientCheckpointType.from_str("OFFLOAD_MATMUL_WITHOUT_BATCH")
+    jax_policy_fn = policy.to_jax_policy()
+    self.assertTrue(callable(jax_policy_fn))
+
+  def test_custom_policy(self):
+    """
+    Tests the custom policy by checking that the returned policy is callable.
+    """
+    policy = GradientCheckpointType.from_str("CUSTOM")
+    names_to_offload = ["attn_output"]
+    jax_policy_fn = policy.to_jax_policy(names_which_can_be_offloaded=names_to_offload)
+    self.assertTrue(callable(jax_policy_fn))
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index b6a73ee5d..84efa064e 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -13,6 +13,7 @@
  See the License for the specific language governing permissions and
  limitations under the License.
""" + import os import jax import jax.numpy as jnp @@ -276,7 +277,7 @@ def test_wan_model(self): ) assert dummy_output.shape == hidden_states_shape - @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule') + @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule") def test_get_qt_provider(self, mock_qt_rule): """ Tests the provider logic for all config branches. @@ -292,11 +293,7 @@ def test_get_qt_provider(self, mock_qt_rule): config_int8.quantization = "int8" provider_int8 = WanPipeline.get_qt_provider(config_int8) self.assertIsNotNone(provider_int8) - mock_qt_rule.assert_called_once_with( - module_path='.*', - weight_qtype=jnp.int8, - act_qtype=jnp.int8 - ) + mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8) # Case 3: Quantization enabled, type 'fp8' mock_qt_rule.reset_mock() @@ -305,11 +302,7 @@ def test_get_qt_provider(self, mock_qt_rule): config_fp8.quantization = "fp8" provider_fp8 = WanPipeline.get_qt_provider(config_fp8) self.assertIsNotNone(provider_fp8) - mock_qt_rule.assert_called_once_with( - module_path='.*', - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn - ) + mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn) # Case 4: Quantization enabled, type 'fp8_full' mock_qt_rule.reset_mock() @@ -320,15 +313,15 @@ def test_get_qt_provider(self, mock_qt_rule): provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full) self.assertIsNotNone(provider_fp8_full) mock_qt_rule.assert_called_once_with( - module_path='.*', # Apply to all modules - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, - bwd_use_original_residuals=True, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method = config_fp8_full.quantization_calibration_method, - act_calibration_method = config_fp8_full.quantization_calibration_method, - bwd_calibration_method = config_fp8_full.quantization_calibration_method, + module_path=".*", # Apply to all modules + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + bwd_use_original_residuals=True, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method=config_fp8_full.quantization_calibration_method, + act_calibration_method=config_fp8_full.quantization_calibration_method, + bwd_calibration_method=config_fp8_full.quantization_calibration_method, ) # Case 5: Invalid quantization type @@ -338,8 +331,8 @@ def test_get_qt_provider(self, mock_qt_rule): self.assertIsNone(WanPipeline.get_qt_provider(config_invalid)) # To test quantize_transformer, we patch its external dependencies - @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model') - @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs') + @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs") def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model): """ Tests that quantize_transformer calls qwix when quantization is enabled. 
@@ -370,14 +363,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
 
     mock_model = Mock(spec=WanModel)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index cc8142159..2c6caf579 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -228,6 +228,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     per_device_tflops = self.calculate_tflops(pipeline)
     scheduler_state = pipeline.scheduler_state
     example_batch = load_next_batch(train_data_iterator, None, self.config)
+
     with ThreadPoolExecutor(max_workers=1) as executor:
       for step in np.arange(start_step, self.config.max_train_steps):
         if self.config.enable_profiler and step == first_profiling_step:
@@ -300,7 +301,7 @@ def train_step(state, data, rng, scheduler_state, scheduler, config):
 
 
 def step_optimizer(state, data, rng, scheduler_state, scheduler, config):
-  _, new_rng, timestep_rng = jax.random.split(rng, num=3)
+  _, new_rng, timestep_rng, dropout_rng = jax.random.split(rng, num=4)
 
   for k, v in data.items():
     data[k] = v[: config.global_batch_size_to_train_on, :]
@@ -323,6 +324,8 @@ def loss_fn(params):
         hidden_states=noisy_latents,
         timestep=timesteps,
         encoder_hidden_states=encoder_hidden_states,
+        deterministic=False,
+        rngs=nnx.Rngs(dropout_rng),
     )
     training_target = scheduler.training_target(latents, noise, timesteps)
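The trainer change above is what activates the new dropout layers: step_optimizer now splits off a dropout_rng each step and passes deterministic=False together with nnx.Rngs(dropout_rng) down through the transformer into each nnx.Dropout. A minimal sketch of that calling convention follows; it is not part of the patch, and TinyBlock with its dimensions is a hypothetical stand-in for the Wan attention and feed-forward modules:

import jax
import jax.numpy as jnp
from flax import nnx


class TinyBlock(nnx.Module):

  def __init__(self, dim: int, dropout: float, rngs: nnx.Rngs):
    self.proj = nnx.Linear(dim, dim, rngs=rngs)
    self.drop = nnx.Dropout(dropout)

  def __call__(self, x, deterministic: bool = True, rngs: nnx.Rngs = None):
    # Same call convention as the Wan modules above: dropout only fires when
    # deterministic=False and an rng stream is supplied at call time.
    return self.drop(self.proj(x), deterministic=deterministic, rngs=rngs)


block = TinyBlock(dim=8, dropout=0.1, rngs=nnx.Rngs(0))
x = jnp.ones((2, 8))

# Inference path: deterministic by default, no dropout key needed.
y_eval = block(x)

# Training path, mirroring step_optimizer: a fresh per-step dropout key.
_, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
y_train = block(x, deterministic=False, rngs=nnx.Rngs(dropout_rng))
print(y_eval.shape, y_train.shape)

At inference time the modules are called with the default deterministic=True, so the dropout layers are a no-op and no key has to be threaded through generate_wan.py.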