From fb126028d1003f417e77469b25865902debb6f66 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 25 Aug 2025 20:29:31 +0000 Subject: [PATCH 1/5] wip - add dropout change sharding --- src/maxdiffusion/configs/base_wan_14b.yml | 3 +- .../input_pipeline/_tfds_data_processing.py | 74 ++++++++++++------- .../input_pipeline_interface.py | 2 +- src/maxdiffusion/models/attention_flax.py | 23 ++++-- .../wan/transformers/transformer_wan.py | 45 +++++++---- .../pipelines/wan/wan_pipeline.py | 1 + .../tests/wan_transformer_test.py | 41 +++++----- src/maxdiffusion/trainers/wan_trainer.py | 7 +- 8 files changed, 121 insertions(+), 75 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 0d3fb969b..78399af27 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -56,8 +56,9 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring flash_min_seq_length: 4096 +dropout: 0.1 flash_block_sizes: {} # Use on v6e diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 885d59ef6..bb0428aa9 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -78,9 +78,18 @@ def make_tf_iterator( train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter + # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py def _make_tfrecord_iterator( - config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description_fn, + prepare_sample_fn, + dataset_path, + is_training: bool, ): # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. @@ -93,10 +102,10 @@ def _make_tfrecord_iterator( # Determine whether to use the "cached" dataset, which requires externally # provided parsing functions, or the default one with its internal parsing logic. make_cached_tfrecord_iterator = ( - config.cache_latents_text_encoder_outputs - and is_dataset_dir_valid - and "load_tfrecord_cached" in config.get_keys() - and config.load_tfrecord_cached + config.cache_latents_text_encoder_outputs + and is_dataset_dir_valid + and "load_tfrecord_cached" in config.get_keys() + and config.load_tfrecord_cached ) feature_description = { @@ -121,42 +130,47 @@ def prepare_sample(features): if not is_training: num_eval_samples = 0 for _ in ds: - num_eval_samples += 1 + num_eval_samples += 1 remainder = num_eval_samples % global_batch_size if remainder != 0: - num_to_pad = global_batch_size - remainder - # Create a dataset of padding samples from the beginning - padding_ds = ds.take(num_to_pad) - # Add the padding samples to the end - ds = ds.concatenate(padding_ds) - max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.") + num_to_pad = global_batch_size - remainder + # Create a dataset of padding samples from the beginning + padding_ds = ds.take(num_to_pad) + # Add the padding samples to the end + ds = ds.concatenate(padding_ds) + max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.") used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample ds = ( - ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) - .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) - .map(used_prepare_sample, num_parallel_calls=AUTOTUNE) + ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) + .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) + .map(used_prepare_sample, num_parallel_calls=AUTOTUNE) ) if is_training: ds = ( - ds.shuffle(global_batch_size * 10) - .batch(global_batch_size // dataloading_host_count, drop_remainder=True) - .repeat(-1) - .prefetch(AUTOTUNE) + ds.shuffle(global_batch_size * 10) + .batch(global_batch_size // dataloading_host_count, drop_remainder=True) + .repeat(-1) + .prefetch(AUTOTUNE) ) # For Evaluation else: - ds = ( - ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False) - .prefetch(AUTOTUNE) - ) + ds = ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False).prefetch(AUTOTUNE) iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh) return iter + def make_tfrecord_iterator( - config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description, + prepare_sample_fn, + is_training, ): """Iterator for TFRecord format. For Laion dataset, check out preparation script @@ -165,4 +179,14 @@ def make_tfrecord_iterator( # Currently only support evaluation on tfrecord. To avoid influencing previous reference, judge whether is training dataset. # TODO: refactor to support evaluation on all dataset format. dataset_path = config.train_data_dir if is_training else config.eval_data_dir - return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training) + return _make_tfrecord_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + feature_description, + prepare_sample_fn, + dataset_path, + is_training, + ) diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index e7014bbc3..16477c35d 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -107,7 +107,7 @@ def make_data_iterator( global_batch_size, feature_description, prepare_sample_fn, - is_training + is_training, ) else: assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)" diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 25788fb69..f78e07576 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -734,7 +734,7 @@ def __init__( # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. # Dims are [num_blocks, embed, heads] - kernel_axes = (None, "embed", "heads") + kernel_axes = ("embed", None, "heads") qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) self.query = nnx.Linear( @@ -748,8 +748,8 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "embed", + "heads", ), ), ) @@ -765,8 +765,8 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "embed", + "heads", ), ), ) @@ -782,8 +782,8 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "embed", + "heads" ), ), ) @@ -792,12 +792,21 @@ def __init__( rngs=rngs, in_features=self.inner_dim, out_features=self.inner_dim, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads", None)), dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + "embed", + None + ), + ), ) + self.drop_out = nnx.Dropout(dropout) + self.norm_q = None self.norm_k = None if qk_norm is not None: @@ -847,7 +856,8 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup return xq_out, xk_out def __call__( - self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None + self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None, + deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) @@ -877,6 +887,7 @@ def __call__( attn_output = attn_output.astype(dtype=dtype) attn_output = checkpoint_name(attn_output, "attn_output") hidden_states = self.proj_attn(attn_output) + hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return hidden_states diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 6588929b1..7dde68b61 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -175,12 +175,11 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( + "embed", None, "mlp", - "embed", ), ), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -217,6 +216,8 @@ def __init__( else: raise NotImplementedError(f"{activation_fn} is not implemented.") + self.drop_out = nnx.Dropout(dropout) + self.proj_out = nnx.Linear( rngs=rngs, in_features=inner_dim, @@ -228,15 +229,16 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - None, "embed", "mlp", + None, ), ), ) - def __call__(self, hidden_states: jax.Array) -> jax.Array: + def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: hidden_states = self.act_fn(hidden_states) + hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return self.proj_out(hidden_states) @@ -260,6 +262,7 @@ def __init__( weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, attention: str = "dot_product", + dropout: float = 0.0, ): # 1. Self-attention @@ -278,6 +281,7 @@ def __init__( weights_dtype=weights_dtype, precision=precision, attention_kernel=attention, + dropout=dropout ) # 1. Cross-attention @@ -295,6 +299,7 @@ def __init__( weights_dtype=weights_dtype, precision=precision, attention_kernel=attention, + dropout=dropout ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -308,13 +313,16 @@ def __init__( dtype=dtype, weights_dtype=weights_dtype, precision=precision, + dropout=dropout ) self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) key = rngs.params() - self.adaln_scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5) + self.adaln_scale_shift_table = nnx.Param( + jax.random.normal(key, (1, 6, dim)) / dim**0.5, + sharding=("embed",)) - def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array): + def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.adaln_scale_shift_table + temb), 6, axis=1 ) @@ -324,18 +332,18 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) attn_output = self.attn1( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb + hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb, deterministic=deterministic, rngs=rngs ) hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs) hidden_states = hidden_states + attn_output # 3. Feed-forward norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) - ff_output = self.ffn(norm_hidden_states) + ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype) return hidden_states @@ -356,6 +364,7 @@ def __init__( freq_dim: int = 256, ffn_dim: int = 13824, num_layers: int = 40, + dropout: float = 0.0, cross_attn_norm: bool = True, qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, @@ -424,6 +433,7 @@ def init_block(rngs): weights_dtype=weights_dtype, precision=precision, attention=attention, + dropout=dropout, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -454,6 +464,8 @@ def __call__( encoder_hidden_states_image: Optional[jax.Array] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = True, + rngs: nnx.Rngs = None, ) -> Union[jax.Array, Dict[str, jax.Array]]: batch_size, _, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size @@ -476,20 +488,21 @@ def __call__( raise NotImplementedError("img2vid is not yet implemented.") def scan_fn(carry, block): - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + hidden_states_carry, rngs_carry = carry + hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry) + new_carry = (hidden_states, rngs_carry) + return new_carry, None - initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) rematted_block_forward = self.gradient_checkpoint.apply(scan_fn) - final_carry = nnx.scan( + initial_carry = (hidden_states, rngs) + final_carry, _ = nnx.scan( rematted_block_forward, length=self.num_layers, in_axes=(nnx.Carry, 0), - out_axes=nnx.Carry, + out_axes=(nnx.Carry, 0), )(initial_carry, self.blocks) - hidden_states = final_carry[0] + hidden_states, _ = final_carry shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 1659d3bb5..2a02679ff 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -82,6 +82,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["flash_block_sizes"] = get_flash_block_sizes(config) wan_config["remat_policy"] = config.remat_policy wan_config["flash_min_seq_length"] = config.flash_min_seq_length + wan_config["dropout"] = config.dropout # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index b6a73ee5d..84efa064e 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import jax import jax.numpy as jnp @@ -276,7 +277,7 @@ def test_wan_model(self): ) assert dummy_output.shape == hidden_states_shape - @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule') + @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule") def test_get_qt_provider(self, mock_qt_rule): """ Tests the provider logic for all config branches. @@ -292,11 +293,7 @@ def test_get_qt_provider(self, mock_qt_rule): config_int8.quantization = "int8" provider_int8 = WanPipeline.get_qt_provider(config_int8) self.assertIsNotNone(provider_int8) - mock_qt_rule.assert_called_once_with( - module_path='.*', - weight_qtype=jnp.int8, - act_qtype=jnp.int8 - ) + mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8) # Case 3: Quantization enabled, type 'fp8' mock_qt_rule.reset_mock() @@ -305,11 +302,7 @@ def test_get_qt_provider(self, mock_qt_rule): config_fp8.quantization = "fp8" provider_fp8 = WanPipeline.get_qt_provider(config_fp8) self.assertIsNotNone(provider_fp8) - mock_qt_rule.assert_called_once_with( - module_path='.*', - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn - ) + mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn) # Case 4: Quantization enabled, type 'fp8_full' mock_qt_rule.reset_mock() @@ -320,15 +313,15 @@ def test_get_qt_provider(self, mock_qt_rule): provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full) self.assertIsNotNone(provider_fp8_full) mock_qt_rule.assert_called_once_with( - module_path='.*', # Apply to all modules - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, - bwd_use_original_residuals=True, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method = config_fp8_full.quantization_calibration_method, - act_calibration_method = config_fp8_full.quantization_calibration_method, - bwd_calibration_method = config_fp8_full.quantization_calibration_method, + module_path=".*", # Apply to all modules + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + bwd_use_original_residuals=True, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method=config_fp8_full.quantization_calibration_method, + act_calibration_method=config_fp8_full.quantization_calibration_method, + bwd_calibration_method=config_fp8_full.quantization_calibration_method, ) # Case 5: Invalid quantization type @@ -338,8 +331,8 @@ def test_get_qt_provider(self, mock_qt_rule): self.assertIsNone(WanPipeline.get_qt_provider(config_invalid)) # To test quantize_transformer, we patch its external dependencies - @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model') - @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs') + @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs") def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model): """ Tests that quantize_transformer calls qwix when quantization is enabled. @@ -370,14 +363,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize # Check that the model returned is the new quantized model self.assertIs(result, mock_quantized_model_obj) - @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model') + @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model") def test_quantize_transformer_disabled(self, mock_quantize_model): """ Tests that quantize_transformer is skipped when quantization is disabled. """ # Setup Mocks mock_config = Mock(spec=HyperParameters) - mock_config.use_qwix_quantization = False # Main condition for this test + mock_config.use_qwix_quantization = False # Main condition for this test mock_model = Mock(spec=WanModel) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 37615b076..55fd6f4cb 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -149,7 +149,7 @@ def start_training(self): pipeline = self.load_checkpoint() # Generate a sample before training to compare against generated sample after training. - pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + #pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") # save some memory. del pipeline.vae @@ -227,6 +227,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data scheduler_state = pipeline.scheduler_state example_batch = load_next_batch(train_data_iterator, None, self.config) + with ThreadPoolExecutor(max_workers=1) as executor: for step in np.arange(start_step, self.config.max_train_steps): if self.config.enable_profiler and step == first_profiling_step: @@ -290,7 +291,7 @@ def train_step(state, data, rng, scheduler_state, scheduler, config): def step_optimizer(state, data, rng, scheduler_state, scheduler, config): - _, new_rng, timestep_rng = jax.random.split(rng, num=3) + _, new_rng, timestep_rng, dropout_rng = jax.random.split(rng, num=4) for k, v in data.items(): data[k] = v[: config.global_batch_size_to_train_on, :] @@ -313,6 +314,8 @@ def loss_fn(params): hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, + deterministic=False, + rngs=nnx.Rngs(dropout_rng) ) training_target = scheduler.training_target(latents, noise, timesteps) From 1ca1b1bc10e1fbdaa7592ae765f29cf71b0cfb5d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 26 Aug 2025 17:18:20 +0000 Subject: [PATCH 2/5] revert shardings --- src/maxdiffusion/models/attention_flax.py | 14 +++++++------- src/maxdiffusion/models/gradient_checkpoint.py | 13 ++++++++++--- .../models/wan/transformers/transformer_wan.py | 9 ++++----- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f78e07576..04919cc59 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -734,7 +734,7 @@ def __init__( # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. # Dims are [num_blocks, embed, heads] - kernel_axes = ("embed", None, "heads") + kernel_axes = (None, "embed", "heads") qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) self.query = nnx.Linear( @@ -748,8 +748,8 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( + None, "embed", - "heads", ), ), ) @@ -765,8 +765,8 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( + None, "embed", - "heads", ), ), ) @@ -782,8 +782,8 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( + None, "embed", - "heads" ), ), ) @@ -792,15 +792,15 @@ def __init__( rngs=rngs, in_features=self.inner_dim, out_features=self.inner_dim, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads", None)), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")), dtype=dtype, param_dtype=weights_dtype, precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - "embed", - None + None, + "heads" ), ), ) diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 28f637c23..b317cc523 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -39,6 +39,7 @@ class GradientCheckpointType(Enum): NONE = auto() FULL = auto() MATMUL_WITHOUT_BATCH = auto() + OFFLOAD_MATMUL_WITHOUT_BATCH = auto() ATTN = auto() @classmethod @@ -65,10 +66,16 @@ def to_jax_policy(self): return SKIP_GRADIENT_CHECKPOINT_KEY case GradientCheckpointType.FULL: return None - case GradientCheckpointType.ATTN: - return cp.save_and_offload_only_these_names( - names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host" + case GradientCheckpointType.OFFLOAD_MATMUL_WITHOUT_BATCH: + return cp.offload_dot_with_no_batch_dims( + offload_src="device", offload_dst="pinned_host" ) + case GradientCheckpointType.ATTN: + offload_policy = cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], names_which_can_be_offloaded=["attn_output"], offload_src="device", offload_dst="pinned_host" + ) + policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + return cp.save_from_both_policies(offload_policy, policy) case GradientCheckpointType.MATMUL_WITHOUT_BATCH: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 7dde68b61..fa6affc01 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -175,11 +175,12 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - "embed", None, "mlp", + "embed", ), ), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -217,7 +218,6 @@ def __init__( raise NotImplementedError(f"{activation_fn} is not implemented.") self.drop_out = nnx.Dropout(dropout) - self.proj_out = nnx.Linear( rngs=rngs, in_features=inner_dim, @@ -229,9 +229,9 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( + None, "embed", "mlp", - None, ), ), ) @@ -319,8 +319,7 @@ def __init__( key = rngs.params() self.adaln_scale_shift_table = nnx.Param( - jax.random.normal(key, (1, 6, dim)) / dim**0.5, - sharding=("embed",)) + jax.random.normal(key, (1, 6, dim)) / dim**0.5,) def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 55fd6f4cb..7a03b2d2d 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -149,7 +149,7 @@ def start_training(self): pipeline = self.load_checkpoint() # Generate a sample before training to compare against generated sample after training. - #pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") # save some memory. del pipeline.vae From 72892187b7f6d11b8d70bc48b649a2dcc7146c0a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 29 Aug 2025 17:39:55 +0000 Subject: [PATCH 3/5] activation offloading --- src/maxdiffusion/models/attention_flax.py | 2 +- .../models/gradient_checkpoint.py | 21 ++++++++++++++----- .../wan/transformers/transformer_wan.py | 10 +++++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 04919cc59..b1ff1a534 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -800,7 +800,7 @@ def __init__( nnx.initializers.zeros, ( None, - "heads" + "heads", ), ), ) diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index b317cc523..86bce9c81 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -71,15 +71,25 @@ def to_jax_policy(self): offload_src="device", offload_dst="pinned_host" ) case GradientCheckpointType.ATTN: - offload_policy = cp.save_and_offload_only_these_names( - names_which_can_be_saved=[], names_which_can_be_offloaded=["attn_output"], offload_src="device", offload_dst="pinned_host" + policy = cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=[ + #"attn_output", + #"query_proj", + #"key_proj", + #"value_proj", + #"xq_out", + #"xk_out", + "ffn_activation" + ], + offload_src="device", + offload_dst="pinned_host" ) - policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - return cp.save_from_both_policies(offload_policy, policy) + return policy case GradientCheckpointType.MATMUL_WITHOUT_BATCH: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - def apply(self, module: nnx.Module) -> nnx.Module: + def apply(self, module: nnx.Module, static_argnums=()) -> nnx.Module: """ Applies a gradient checkpoint policy to a module if no policy is needed, it will return the module as is @@ -97,4 +107,5 @@ def apply(self, module: nnx.Module) -> nnx.Module: module, prevent_cse=False, policy=policy, + static_argnums=static_argnums ) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index fa6affc01..f48635a45 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp from jax.sharding import PartitionSpec +from jax.ad_checkpoint import checkpoint_name from flax import nnx import numpy as np from .... import common_types @@ -42,7 +43,7 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): t_dim = attention_head_dim - h_dim - w_dim freqs = [] for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False) + freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=False) freqs.append(freq) freqs = jnp.concatenate(freqs, axis=1) t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6) @@ -180,7 +181,7 @@ def __init__( "embed", ), ), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -237,9 +238,10 @@ def __init__( ) def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: - hidden_states = self.act_fn(hidden_states) + hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) + hidden_states = checkpoint_name(hidden_states, "ffn_activation") hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) - return self.proj_out(hidden_states) + return self.proj_out(hidden_states) # output is (4, 75600, 5120) class WanTransformerBlock(nnx.Module): From 0d9f868a791905cede431b9c44140cad84bc2520 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 2 Sep 2025 18:05:58 +0000 Subject: [PATCH 4/5] add gc tests. add metadata axis to scan. --- src/maxdiffusion/configs/base_wan_14b.yml | 8 ++- src/maxdiffusion/models/attention_flax.py | 10 +--- .../models/gradient_checkpoint.py | 22 +++---- .../wan/transformers/transformer_wan.py | 12 ++-- .../pipelines/wan/wan_pipeline.py | 2 + src/maxdiffusion/pyconfig.py | 2 + .../tests/gradient_checkpoint_test.py | 59 +++++++++++++++++++ 7 files changed, 86 insertions(+), 29 deletions(-) create mode 100644 src/maxdiffusion/tests/gradient_checkpoint_test.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 78399af27..a76add01d 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -194,8 +194,14 @@ enable_data_shuffling: True # FULL - means full gradient checkpoint, whenever possible (minimum memory usage) # MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, # except for ones that involve batch dimension - that means that all attention and projection -# layers will have gradient checkpoint, but not the backward with respect to the parameters +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index b1ff1a534..8cfe32e71 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -734,7 +734,7 @@ def __init__( # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. # Dims are [num_blocks, embed, heads] - kernel_axes = (None, "embed", "heads") + kernel_axes = ("embed", "heads") qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) self.query = nnx.Linear( @@ -748,7 +748,6 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "embed", ), ), @@ -765,7 +764,6 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "embed", ), ), @@ -782,7 +780,6 @@ def __init__( bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "embed", ), ), @@ -792,14 +789,13 @@ def __init__( rngs=rngs, in_features=self.inner_dim, out_features=self.inner_dim, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), dtype=dtype, param_dtype=weights_dtype, precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, ( - None, "heads", ), ), @@ -818,7 +814,6 @@ def __init__( scale_init=nnx.with_partitioning( nnx.initializers.ones, ( - None, "norm", ), ), @@ -832,7 +827,6 @@ def __init__( scale_init=nnx.with_partitioning( nnx.initializers.ones, ( - None, "norm", ), ), diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 86bce9c81..d1286e0ca 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -40,7 +40,7 @@ class GradientCheckpointType(Enum): FULL = auto() MATMUL_WITHOUT_BATCH = auto() OFFLOAD_MATMUL_WITHOUT_BATCH = auto() - ATTN = auto() + CUSTOM = auto() @classmethod def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": @@ -57,7 +57,7 @@ def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": s = "none" return GradientCheckpointType[s.upper()] - def to_jax_policy(self): + def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = []): """ Converts the gradient checkpoint type to a jax policy """ @@ -70,18 +70,10 @@ def to_jax_policy(self): return cp.offload_dot_with_no_batch_dims( offload_src="device", offload_dst="pinned_host" ) - case GradientCheckpointType.ATTN: + case GradientCheckpointType.CUSTOM: policy = cp.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=[ - #"attn_output", - #"query_proj", - #"key_proj", - #"value_proj", - #"xq_out", - #"xk_out", - "ffn_activation" - ], + names_which_can_be_saved=names_which_can_be_saved, + names_which_can_be_offloaded=names_which_can_be_offloaded, offload_src="device", offload_dst="pinned_host" ) @@ -89,7 +81,7 @@ def to_jax_policy(self): case GradientCheckpointType.MATMUL_WITHOUT_BATCH: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - def apply(self, module: nnx.Module, static_argnums=()) -> nnx.Module: + def apply(self, module: nnx.Module, names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], static_argnums=()) -> nnx.Module: """ Applies a gradient checkpoint policy to a module if no policy is needed, it will return the module as is @@ -100,7 +92,7 @@ def apply(self, module: nnx.Module, static_argnums=()) -> nnx.Module: Returns: nn.Module: the module with the policy applied """ - policy = self.to_jax_policy() + policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded) if policy == SKIP_GRADIENT_CHECKPOINT_KEY: return module return nnx.remat( # pylint: disable=invalid-name diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index f48635a45..57bd6a2d6 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -176,12 +176,11 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - None, "mlp", "embed", ), ), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -230,7 +229,6 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - None, "embed", "mlp", ), @@ -381,6 +379,8 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", remat_policy: str = "None", + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [] ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -417,7 +417,7 @@ def __init__( # 3. Transformer blocks @nnx.split_rngs(splits=num_layers) - @nnx.vmap(in_axes=0, out_axes=0) + @nnx.vmap(in_axes=0, out_axes=0, transform_metadata= {nnx.PARTITION_NAME: "layers_per_stage"} ) def init_block(rngs): return WanTransformerBlock( rngs=rngs, @@ -438,6 +438,8 @@ def init_block(rngs): ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) + self.names_which_can_be_offloaded = names_which_can_be_offloaded + self.names_which_can_be_saved = names_which_can_be_saved self.blocks = init_block(rngs) @@ -494,7 +496,7 @@ def scan_fn(carry, block): new_carry = (hidden_states, rngs_carry) return new_carry, None - rematted_block_forward = self.gradient_checkpoint.apply(scan_fn) + rematted_block_forward = self.gradient_checkpoint.apply(scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded) initial_carry = (hidden_states, rngs) final_carry, _ = nnx.scan( rematted_block_forward, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 2a02679ff..5fc3a5f42 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -81,6 +81,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["precision"] = get_precision(config) wan_config["flash_block_sizes"] = get_flash_block_sizes(config) wan_config["remat_policy"] = config.remat_policy + wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved + wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded wan_config["flash_min_seq_length"] = config.flash_min_seq_length wan_config["dropout"] = config.dropout diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 33fc62f83..3eb4a23be 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -120,6 +120,8 @@ def _load_kwargs(self, argv: list[str]): @staticmethod def wan_init(raw_keys): + if not any("layers_per_stage" in inner_tuple for inner_tuple in raw_keys["logical_axis_rules"]): + raw_keys["logical_axis_rules"]+= (("layers_per_stage", None),) if "wan_transformer_pretrained_model_name_or_path" in raw_keys: transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] if transformer_pretrained_model_name_or_path == "": diff --git a/src/maxdiffusion/tests/gradient_checkpoint_test.py b/src/maxdiffusion/tests/gradient_checkpoint_test.py new file mode 100644 index 000000000..049308e7c --- /dev/null +++ b/src/maxdiffusion/tests/gradient_checkpoint_test.py @@ -0,0 +1,59 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest +from types import SimpleNamespace +from absl.testing import absltest + +import jax + +from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType + +class GradientCheckpointTest(unittest.TestCase): + """Unit test suite for GradientCheckpointType policies.""" + + def test_none_policy(self): + policy = GradientCheckpointType.from_str("NONE") + self.assertEqual(policy.to_jax_policy(), "skip") + + def test_full_policy(self): + policy = GradientCheckpointType.from_str("FULL") + self.assertIsNone(policy.to_jax_policy()) + + def test_matmul_without_batch_policy(self): + policy = GradientCheckpointType.from_str("MATMUL_WITHOUT_BATCH") + jax_policy_fn = policy.to_jax_policy() + self.assertIs(jax_policy_fn, jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims) + + def test_offload_matmul_without_batch_policy(self): + """ + Tests the offload variant by checking the class name of the return value. + """ + policy = GradientCheckpointType.from_str("OFFLOAD_MATMUL_WITHOUT_BATCH") + jax_policy_fn = policy.to_jax_policy() + self.assertTrue(callable(jax_policy_fn)) + + def test_custom_policy(self): + """ + Tests the custom policy by checking the class name of the return value. + """ + policy = GradientCheckpointType.from_str("CUSTOM") + names_to_offload = ["attn_output"] + jax_policy_fn = policy.to_jax_policy(names_which_can_be_offloaded=names_to_offload) + self.assertTrue(callable(jax_policy_fn)) + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 4e362c39f17a5ab46938101f6f3dfddb3cfaa862 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 2 Sep 2025 18:21:13 +0000 Subject: [PATCH 5/5] merge and lint --- src/maxdiffusion/configuration_utils.py | 43 ++++++------ src/maxdiffusion/generate_wan.py | 53 ++++++++------- src/maxdiffusion/models/attention_flax.py | 32 ++++----- .../models/gradient_checkpoint.py | 29 ++++---- .../wan/transformers/transformer_wan.py | 41 ++++++++---- src/maxdiffusion/pyconfig.py | 2 +- .../tests/gradient_checkpoint_test.py | 67 ++++++++++--------- src/maxdiffusion/trainers/wan_trainer.py | 2 +- 8 files changed, 143 insertions(+), 126 deletions(-) diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 5d1785070..8463ebaa1 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -47,21 +47,24 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json") + class CustomEncoder(json.JSONEncoder): - """ - Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes. - """ - def default(self, o): - # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16" - if isinstance(o, type(jnp.dtype('bfloat16'))): - return str(o) - # Add fallbacks for other numpy types if needed - if isinstance(o, np.integer): - return int(o) - if isinstance(o, np.floating): - return float(o) - # Let the base class default method raise the TypeError for other types - return super().default(o) + """ + Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes. + """ + + def default(self, o): + # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16" + if isinstance(o, type(jnp.dtype("bfloat16"))): + return str(o) + # Add fallbacks for other numpy types if needed + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + # Let the base class default method raise the TypeError for other types + return super().default(o) + class FrozenDict(OrderedDict): @@ -596,14 +599,14 @@ def to_json_saveable(value): config_dict.pop("quant", None) keys_to_remove = [] for key, value in config_dict.items(): - # Check the type of the value by its class name to avoid import issues - if type(value).__name__ == 'Rngs': - keys_to_remove.append(key) + # Check the type of the value by its class name to avoid import issues + if type(value).__name__ == "Rngs": + keys_to_remove.append(key) if keys_to_remove: - max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}") - for key in keys_to_remove: - config_dict.pop(key) + max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}") + for key in keys_to_remove: + config_dict.pop(key) try: diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 1dc1789a1..3530d5eb0 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -22,43 +22,47 @@ from maxdiffusion.utils import export_to_video from google.cloud import storage + def upload_video_to_gcs(output_dir: str, video_path: str): - """ - Uploads a local video file to a specified Google Cloud Storage bucket. - """ - try: - path_without_scheme = output_dir.removeprefix("gs://") - parts = path_without_scheme.split('/', 1) - bucket_name = parts[0] - folder_name = parts[1] if len(parts) > 1 else '' + """ + Uploads a local video file to a specified Google Cloud Storage bucket. + """ + try: + path_without_scheme = output_dir.removeprefix("gs://") + parts = path_without_scheme.split("/", 1) + bucket_name = parts[0] + folder_name = parts[1] if len(parts) > 1 else "" - storage_client = storage.Client() - bucket = storage_client.bucket(bucket_name) + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) - source_file_path = f"./{video_path}" - destination_blob_name = os.path.join(folder_name, "videos", video_path) + source_file_path = f"./{video_path}" + destination_blob_name = os.path.join(folder_name, "videos", video_path) - blob = bucket.blob(destination_blob_name) + blob = bucket.blob(destination_blob_name) - max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...") - blob.upload_from_filename(source_file_path) - max_logging.log(f"Upload complete {source_file_path}.") + max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...") + blob.upload_from_filename(source_file_path) + max_logging.log(f"Upload complete {source_file_path}.") + + except Exception as e: + max_logging.log(f"An error occurred: {e}") - except Exception as e: - max_logging.log(f"An error occurred: {e}") def delete_file(file_path: str): if os.path.exists(file_path): - try: - os.remove(file_path) - max_logging.log(f"Successfully deleted file: {file_path}") - except OSError as e: - max_logging.log(f"Error deleting file '{file_path}': {e}") + try: + os.remove(file_path) + max_logging.log(f"Successfully deleted file: {file_path}") + except OSError as e: + max_logging.log(f"Error deleting file '{file_path}': {e}") else: - max_logging.log(f"The file '{file_path}' does not exist.") + max_logging.log(f"The file '{file_path}' does not exist.") + jax.config.update("jax_use_shardy_partitioner", True) + def inference_generate_video(config, pipeline, filename_prefix=""): s0 = time.perf_counter() prompt = [config.prompt] * config.global_batch_size_to_train_on @@ -88,6 +92,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""): delete_file(f"./{video_path}") return + def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 8cfe32e71..3cbb0ccea 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -747,9 +747,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ( - "embed", - ), + ("embed",), ), ) @@ -763,9 +761,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ( - "embed", - ), + ("embed",), ), ) @@ -779,9 +775,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ( - "embed", - ), + ("embed",), ), ) @@ -795,9 +789,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ( - "heads", - ), + ("heads",), ), ) @@ -813,9 +805,7 @@ def __init__( dtype=dtype, scale_init=nnx.with_partitioning( nnx.initializers.ones, - ( - "norm", - ), + ("norm",), ), param_dtype=weights_dtype, ) @@ -826,9 +816,7 @@ def __init__( dtype=dtype, scale_init=nnx.with_partitioning( nnx.initializers.ones, - ( - "norm", - ), + ("norm",), ), param_dtype=weights_dtype, ) @@ -850,8 +838,12 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup return xq_out, xk_out def __call__( - self, hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None, - deterministic: bool = True, rngs: nnx.Rngs = None, + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array = None, + rotary_emb: Optional[jax.Array] = None, + deterministic: bool = True, + rngs: nnx.Rngs = None, ) -> jax.Array: hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index d1286e0ca..a111ef717 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -67,21 +67,25 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_ case GradientCheckpointType.FULL: return None case GradientCheckpointType.OFFLOAD_MATMUL_WITHOUT_BATCH: - return cp.offload_dot_with_no_batch_dims( - offload_src="device", offload_dst="pinned_host" - ) + return cp.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") case GradientCheckpointType.CUSTOM: policy = cp.save_and_offload_only_these_names( - names_which_can_be_saved=names_which_can_be_saved, - names_which_can_be_offloaded=names_which_can_be_offloaded, - offload_src="device", - offload_dst="pinned_host" - ) + names_which_can_be_saved=names_which_can_be_saved, + names_which_can_be_offloaded=names_which_can_be_offloaded, + offload_src="device", + offload_dst="pinned_host", + ) return policy case GradientCheckpointType.MATMUL_WITHOUT_BATCH: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - def apply(self, module: nnx.Module, names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], static_argnums=()) -> nnx.Module: + def apply( + self, + module: nnx.Module, + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + static_argnums=(), + ) -> nnx.Module: """ Applies a gradient checkpoint policy to a module if no policy is needed, it will return the module as is @@ -95,9 +99,4 @@ def apply(self, module: nnx.Module, names_which_can_be_saved: list = [], names_w policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded) if policy == SKIP_GRADIENT_CHECKPOINT_KEY: return module - return nnx.remat( # pylint: disable=invalid-name - module, - prevent_cse=False, - policy=policy, - static_argnums=static_argnums - ) + return nnx.remat(module, prevent_cse=False, policy=policy, static_argnums=static_argnums) # pylint: disable=invalid-name diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 57bd6a2d6..718b5015e 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -236,10 +236,10 @@ def __init__( ) def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: - hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) + hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) hidden_states = checkpoint_name(hidden_states, "ffn_activation") hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) - return self.proj_out(hidden_states) # output is (4, 75600, 5120) + return self.proj_out(hidden_states) # output is (4, 75600, 5120) class WanTransformerBlock(nnx.Module): @@ -281,7 +281,7 @@ def __init__( weights_dtype=weights_dtype, precision=precision, attention_kernel=attention, - dropout=dropout + dropout=dropout, ) # 1. Cross-attention @@ -299,7 +299,7 @@ def __init__( weights_dtype=weights_dtype, precision=precision, attention_kernel=attention, - dropout=dropout + dropout=dropout, ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -313,15 +313,24 @@ def __init__( dtype=dtype, weights_dtype=weights_dtype, precision=precision, - dropout=dropout + dropout=dropout, ) self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) key = rngs.params() self.adaln_scale_shift_table = nnx.Param( - jax.random.normal(key, (1, 6, dim)) / dim**0.5,) + jax.random.normal(key, (1, 6, dim)) / dim**0.5, + ) - def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,): + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array, + temb: jax.Array, + rotary_emb: jax.Array, + deterministic: bool = True, + rngs: nnx.Rngs = None, + ): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.adaln_scale_shift_table + temb), 6, axis=1 ) @@ -331,13 +340,19 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) attn_output = self.attn1( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb, deterministic=deterministic, rngs=rngs + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + deterministic=deterministic, + rngs=rngs, ) hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs) + attn_output = self.attn2( + hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs + ) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -380,7 +395,7 @@ def __init__( attention: str = "dot_product", remat_policy: str = "None", names_which_can_be_saved: list = [], - names_which_can_be_offloaded: list = [] + names_which_can_be_offloaded: list = [], ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -417,7 +432,7 @@ def __init__( # 3. Transformer blocks @nnx.split_rngs(splits=num_layers) - @nnx.vmap(in_axes=0, out_axes=0, transform_metadata= {nnx.PARTITION_NAME: "layers_per_stage"} ) + @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}) def init_block(rngs): return WanTransformerBlock( rngs=rngs, @@ -496,7 +511,9 @@ def scan_fn(carry, block): new_carry = (hidden_states, rngs_carry) return new_carry, None - rematted_block_forward = self.gradient_checkpoint.apply(scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded) + rematted_block_forward = self.gradient_checkpoint.apply( + scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded + ) initial_carry = (hidden_states, rngs) final_carry, _ = nnx.scan( rematted_block_forward, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 3eb4a23be..3bb5bd13c 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -121,7 +121,7 @@ def _load_kwargs(self, argv: list[str]): @staticmethod def wan_init(raw_keys): if not any("layers_per_stage" in inner_tuple for inner_tuple in raw_keys["logical_axis_rules"]): - raw_keys["logical_axis_rules"]+= (("layers_per_stage", None),) + raw_keys["logical_axis_rules"] += (("layers_per_stage", None),) if "wan_transformer_pretrained_model_name_or_path" in raw_keys: transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"] if transformer_pretrained_model_name_or_path == "": diff --git a/src/maxdiffusion/tests/gradient_checkpoint_test.py b/src/maxdiffusion/tests/gradient_checkpoint_test.py index 049308e7c..ca237d523 100644 --- a/src/maxdiffusion/tests/gradient_checkpoint_test.py +++ b/src/maxdiffusion/tests/gradient_checkpoint_test.py @@ -15,45 +15,46 @@ """ import unittest -from types import SimpleNamespace from absl.testing import absltest import jax from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType + class GradientCheckpointTest(unittest.TestCase): - """Unit test suite for GradientCheckpointType policies.""" - - def test_none_policy(self): - policy = GradientCheckpointType.from_str("NONE") - self.assertEqual(policy.to_jax_policy(), "skip") - - def test_full_policy(self): - policy = GradientCheckpointType.from_str("FULL") - self.assertIsNone(policy.to_jax_policy()) - - def test_matmul_without_batch_policy(self): - policy = GradientCheckpointType.from_str("MATMUL_WITHOUT_BATCH") - jax_policy_fn = policy.to_jax_policy() - self.assertIs(jax_policy_fn, jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims) - - def test_offload_matmul_without_batch_policy(self): - """ - Tests the offload variant by checking the class name of the return value. - """ - policy = GradientCheckpointType.from_str("OFFLOAD_MATMUL_WITHOUT_BATCH") - jax_policy_fn = policy.to_jax_policy() - self.assertTrue(callable(jax_policy_fn)) - - def test_custom_policy(self): - """ - Tests the custom policy by checking the class name of the return value. - """ - policy = GradientCheckpointType.from_str("CUSTOM") - names_to_offload = ["attn_output"] - jax_policy_fn = policy.to_jax_policy(names_which_can_be_offloaded=names_to_offload) - self.assertTrue(callable(jax_policy_fn)) + """Unit test suite for GradientCheckpointType policies.""" + + def test_none_policy(self): + policy = GradientCheckpointType.from_str("NONE") + self.assertEqual(policy.to_jax_policy(), "skip") + + def test_full_policy(self): + policy = GradientCheckpointType.from_str("FULL") + self.assertIsNone(policy.to_jax_policy()) + + def test_matmul_without_batch_policy(self): + policy = GradientCheckpointType.from_str("MATMUL_WITHOUT_BATCH") + jax_policy_fn = policy.to_jax_policy() + self.assertIs(jax_policy_fn, jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims) + + def test_offload_matmul_without_batch_policy(self): + """ + Tests the offload variant by checking the class name of the return value. + """ + policy = GradientCheckpointType.from_str("OFFLOAD_MATMUL_WITHOUT_BATCH") + jax_policy_fn = policy.to_jax_policy() + self.assertTrue(callable(jax_policy_fn)) + + def test_custom_policy(self): + """ + Tests the custom policy by checking the class name of the return value. + """ + policy = GradientCheckpointType.from_str("CUSTOM") + names_to_offload = ["attn_output"] + jax_policy_fn = policy.to_jax_policy(names_which_can_be_offloaded=names_to_offload) + self.assertTrue(callable(jax_policy_fn)) + if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index c7aec90e1..2c6caf579 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -325,7 +325,7 @@ def loss_fn(params): timestep=timesteps, encoder_hidden_states=encoder_hidden_states, deterministic=False, - rngs=nnx.Rngs(dropout_rng) + rngs=nnx.Rngs(dropout_rng), ) training_target = scheduler.training_target(latents, noise, timesteps)