diff --git a/requirements.txt b/requirements.txt
index b66f58234..d51653fa0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,4 +35,5 @@ sentencepiece
 aqtp
 imageio==2.37.0
 imageio-ffmpeg==0.6.0
-hf_transfer>=0.1.9
\ No newline at end of file
+hf_transfer>=0.1.9
+qwix@git+https://github.com/google/qwix.git
\ No newline at end of file
diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt
index ea42e1cd4..955a5e76f 100644
--- a/requirements_with_jax_ai_image.txt
+++ b/requirements_with_jax_ai_image.txt
@@ -35,4 +35,5 @@ sentencepiece
 aqtp
 imageio==2.37.0
 imageio-ffmpeg==0.6.0
-hf_transfer>=0.1.9
\ No newline at end of file
+hf_transfer>=0.1.9
+qwix@git+https://github.com/google/qwix.git
\ No newline at end of file
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index b552b0621..40a76c6cf 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -287,4 +287,7 @@ quantization: ''
 # Shard the range finding operation for quantization. By default this is set to number of slices.
 quantization_local_shard_count: -1
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
+use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the WAN transformer will be quantized using qwix.
+# Quantization calibration method used for weights and activations. Supported methods are listed in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
+quantization_calibration_method: "absmax"
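A note on how the new keys interact: `use_qwix_quantization` is the master switch, the pre-existing `quantization` field selects the scheme, and `quantization_calibration_method` is passed through to qwix's calibration. A minimal sketch (not part of the patch) of the values accepted once the pyconfig validation below lands; `raw_keys` mirrors the parsed-YAML dict that pyconfig.py operates on:

```python
# Sketch only: the config knobs this change introduces/uses.
raw_keys = {
    "use_qwix_quantization": True,  # master switch for qwix quantization
    "quantization": "fp8_full",  # one of "int8", "fp8", "fp8_full"
    "quantization_calibration_method": "absmax",  # see qwix/qconfig.py
}
```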
+ """ + rules = [ + qwix.QtRule( + module_path='.*', # Apply to all modules + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + bwd_use_original_residuals=True, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method = quantization_calibration_method, + act_calibration_method = quantization_calibration_method, + bwd_calibration_method = quantization_calibration_method, + ) + ] + return rules + + @classmethod + def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: + """Get quantization rules based on the config.""" + if not getattr(config, "use_qwix_quantization", False): + return None + + quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax") + match config.quantization: + case "int8": + return qwix.QtProvider(cls.get_basic_config(jnp.int8)) + case "fp8": + return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn)) + case "fp8_full": + return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method)) + return None + + @classmethod + def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): + """Quantizes the transformer model.""" + q_rules = cls.get_qt_provider(config) + if not q_rules: + return model + max_logging.log("Quantizing transformer with Qwix.") + + batch_size = int(config.per_device_batch_size * jax.local_device_count()) + latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) + model_inputs= (latents, timesteps, prompt_embeds) + with mesh: + quantized_model = qwix.quantize_model(model, q_rules, *model_inputs) + max_logging.log("Qwix Quantization complete.") + return quantized_model + @classmethod def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): with mesh: @@ -264,7 +331,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform with mesh: wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - return WanPipeline( + pipeline = WanPipeline( tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, @@ -277,6 +344,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform config=config, ) + pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh) + return pipeline + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 8e758d661..501294e28 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -136,6 +136,14 @@ def wan_init(raw_keys): ) else: raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1") + if "use_qwix_quantization" not in raw_keys: + raise ValueError(f"use_qwix_quantization is not set.") + elif raw_keys["use_qwix_quantization"]: + if "quantization" not in raw_keys: + raise ValueError(f"Quantization type is not set when use_qwix_quantization is enabled.") + elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]: + raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}") + @staticmethod def calculate_global_batch_sizes(per_device_batch_size): diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 01d169b01..f25c24c18 100644 --- 
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 8e758d661..501294e28 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -136,6 +136,14 @@ def wan_init(raw_keys):
       )
     else:
       raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
+    if "use_qwix_quantization" not in raw_keys:
+      raise ValueError("use_qwix_quantization is not set.")
+    elif raw_keys["use_qwix_quantization"]:
+      if "quantization" not in raw_keys:
+        raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
+      elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
+        raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
+

  @staticmethod
  def calculate_global_batch_sizes(per_device_batch_size):
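The validation above admits exactly three schemes when the switch is on. A self-contained copy of the same checks, usable to see each outcome in isolation (the `validate` helper is hypothetical, not part of the patch):

```python
def validate(raw_keys: dict) -> None:
  # Hypothetical standalone copy of the checks added to wan_init above.
  if "use_qwix_quantization" not in raw_keys:
    raise ValueError("use_qwix_quantization is not set.")
  if raw_keys["use_qwix_quantization"]:
    if "quantization" not in raw_keys:
      raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
    if raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
      raise ValueError(f"Quantization type is not supported: {raw_keys['quantization']}")

validate({"use_qwix_quantization": False})                         # ok: quantization ignored
validate({"use_qwix_quantization": True, "quantization": "int8"})  # ok
# validate({"use_qwix_quantization": True})                        # ValueError: type not set
# validate({"use_qwix_quantization": True, "quantization": "int4"})  # ValueError: unsupported
```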
+ """ + # Setup Mocks + mock_config = Mock(spec=HyperParameters) + mock_config.use_qwix_quantization = True + mock_config.quantization = "fp8_full" + mock_config.per_device_batch_size = 1 + + mock_model = Mock(spec=WanModel) + mock_pipeline = Mock() + mock_mesh = Mock() + + # Mock the return values of dependencies + mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock()) + mock_quantized_model_obj = Mock(spec=WanModel) + mock_quantize_model.return_value = mock_quantized_model_obj + + # Call the method under test + result = WanPipeline.quantize_transformer(mock_config, mock_model, mock_pipeline, mock_mesh) + + # Assertions + mock_get_dummy_inputs.assert_called_once() + mock_quantize_model.assert_called_once() + # Check that the model returned is the new quantized model + self.assertIs(result, mock_quantized_model_obj) + + @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model') + def test_quantize_transformer_disabled(self, mock_quantize_model): + """ + Tests that quantize_transformer is skipped when quantization is disabled. + """ + # Setup Mocks + mock_config = Mock(spec=HyperParameters) + mock_config.use_qwix_quantization = False # Main condition for this test + + mock_model = Mock(spec=WanModel) + + # Call the method under test + result = WanPipeline.quantize_transformer(mock_config, mock_model, Mock(), Mock()) + + # Assertions + mock_quantize_model.assert_not_called() + # Check that the model returned is the original model instance + self.assertIs(result, mock_model) if __name__ == "__main__":