diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 40a76c6cf..ab764366b 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -40,6 +40,9 @@ weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
 
+# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
+replicate_vae: False
+
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 96b60426d..958cff916 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -221,6 +221,10 @@ def walk_and_upload_blobs(config, output_dir):
 
 
 def device_put_replicated(x, sharding):
+  """
+  Although the name indicates replication, this function can also be used
+  to shard an array according to `sharding`.
+  """
   return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])
 
 
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index fe86e08c4..fcdb7cf65 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -166,20 +166,25 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
-
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  # Cross-attention: kv dims are much smaller because they come from encoder_hidden_states.
+  # If the kv seq_len is padded too much, it causes issues in the attention calculations.
+  if key.shape[1] != query.shape[1]:
+    kv_max_block_size = key.shape[1]
+  else:
+    kv_max_block_size = q_max_block_size
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
-        block_kv=min(max_block_size, key.shape[2]),
-        block_q_dkv=min(max_block_size, query.shape[2]),
-        block_kv_dkv=min(max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
-        block_q_dq=min(max_block_size, query.shape[2]),
-        block_kv_dq=min(max_block_size, query.shape[2]),
+        block_q=min(q_max_block_size, query.shape[2]),
+        block_kv_compute=min(kv_max_block_size, key.shape[2]),
+        block_kv=min(kv_max_block_size, key.shape[2]),
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=min(q_max_block_size, query.shape[2]),
+        block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )
 
   num_fsdp_shards = mesh.shape["fsdp"]
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 3958f5240..1659d3bb5 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -220,6 +220,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
     for path, val in flax.traverse_util.flatten_dict(params).items():
       sharding = logical_state_sharding[path].value
+      if config.replicate_vae:
+        sharding = NamedSharding(mesh, P())
       state[path].value = device_put_replicated(val, sharding)
     state = nnx.from_flat_state(state)
 
@@ -231,11 +233,11 @@ def get_basic_config(cls, dtype):
     rules = [
         qwix.QtRule(
-          module_path='.*', # Apply to all modules
-          weight_qtype=dtype,
-          act_qtype=dtype,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=dtype,
+            act_qtype=dtype,
         )
-      ]
+    ]
     return rules
 
   @classmethod
@@ -247,17 +249,17 @@ def get_fp8_config(cls, quantization_calibration_method: str):
     """
     rules = [
         qwix.QtRule(
-          module_path='.*', # Apply to all modules
-          weight_qtype=jnp.float8_e4m3fn,
-          act_qtype=jnp.float8_e4m3fn,
-          bwd_qtype=jnp.float8_e5m2,
-          bwd_use_original_residuals=True,
-          disable_channelwise_axes=True, # per_tensor calibration
-          weight_calibration_method = quantization_calibration_method,
-          act_calibration_method = quantization_calibration_method,
-          bwd_calibration_method = quantization_calibration_method,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=quantization_calibration_method,
+            act_calibration_method=quantization_calibration_method,
+            bwd_calibration_method=quantization_calibration_method,
         )
-      ]
+    ]
     return rules
 
   @classmethod
@@ -286,7 +288,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
     batch_size = int(config.per_device_batch_size * jax.local_device_count())
     latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
-    model_inputs= (latents, timesteps, prompt_embeds)
+    model_inputs = (latents, timesteps, prompt_embeds)
     with mesh:
       quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
     max_logging.log("Qwix Quantization complete.")
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 150196bbe..182a427bb 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -142,7 +142,9 @@ def wan_init(raw_keys):
     if "quantization" not in raw_keys:
       raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
     elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
-      raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
+      raise ValueError(
+          f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}"
+      )
 
   @staticmethod
   def calculate_global_batch_sizes(per_device_batch_size):
diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index b6a73ee5d..eddf8bf94 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -13,6 +13,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp
@@ -290,7 +292,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8 = WanPipeline.get_qt_provider(config_int8)
+    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     mock_qt_rule.assert_called_once_with(
         module_path='.*',
@@ -305,11 +307,7 @@
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    mock_qt_rule.assert_called_once_with(
-        module_path='.*',
-        weight_qtype=jnp.float8_e4m3fn,
-        act_qtype=jnp.float8_e4m3fn
-    )
+    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
 
     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
@@ -319,17 +317,7 @@
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    mock_qt_rule.assert_called_once_with(
-        module_path='.*', # Apply to all modules
-        weight_qtype=jnp.float8_e4m3fn,
-        act_qtype=jnp.float8_e4m3fn,
-        bwd_qtype=jnp.float8_e5m2,
-        bwd_use_original_residuals=True,
-        disable_channelwise_axes=True, # per_tensor calibration
-        weight_calibration_method = config_fp8_full.quantization_calibration_method,
-        act_calibration_method = config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method = config_fp8_full.quantization_calibration_method,
-    )
+    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -338,8 +326,8 @@
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
 
   # To test quantize_transformer, we patch its external dependencies
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.
     """
@@ -370,14 +358,14 @@
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
     mock_model = Mock(spec=WanModel)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 3b0b520bf..ff4207522 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -36,6 +36,7 @@
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 
 
 class TrainState(train_state.TrainState):
@@ -53,6 +54,12 @@ def generate_sample(config, pipeline, filename_prefix):
   """
   Generates a video to validate training did not corrupt the model
   """
+  if not hasattr(pipeline, "vae"):
+    wan_vae, vae_cache = WanPipeline.load_vae(
+        pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config
+    )
+    pipeline.vae = wan_vae
+    pipeline.vae_cache = vae_cache
   return generate_wan(config, pipeline, filename_prefix)
 
 
@@ -140,10 +147,13 @@ def prepare_sample(features):
 
   def start_training(self):
 
     pipeline = self.load_checkpoint()
-    # del pipeline.vae
-
     # Generate a sample before training to compare against generated sample after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+
+    # save some memory.
+    del pipeline.vae
+    del pipeline.vae_cache
+
     mesh = pipeline.mesh
     data_iterator = self.load_dataset(mesh)
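
Note on the `replicate_vae` path in the diff above: passing `NamedSharding(mesh, P())` to `device_put_replicated` places each VAE parameter whole on every device, while a partitioned `PartitionSpec` makes the same call shard it. A minimal, self-contained sketch of that behavior (not part of the patch; the toy mesh, array shapes, and prints are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def device_put_replicated(x, sharding):
  # Same body as max_utils.device_put_replicated in the diff: the callback hands
  # each device the slice of `x` that `sharding` assigns to it, so an empty
  # PartitionSpec replicates the array and a non-empty one shards it.
  return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])


# Toy one-axis mesh named "fsdp" over all local devices (illustrative only).
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
x = jnp.ones((jax.device_count() * 2, 4))

replicated = device_put_replicated(x, NamedSharding(mesh, P()))       # full copy on every device
sharded = device_put_replicated(x, NamedSharding(mesh, P("fsdp")))    # rows split across "fsdp"

print(replicated.sharding.is_fully_replicated)  # True
print(sharded.sharding.is_fully_replicated)     # False (unless there is only one device)
```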