From 1e64ebdde2feec0a57430124616a53b9ba239969 Mon Sep 17 00:00:00 2001 From: susanbao Date: Thu, 2 Oct 2025 18:11:48 +0000 Subject: [PATCH 1/3] qwix --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + .../pipelines/wan/wan_pipeline.py | 25 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 378741eb1..d6feba016 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -310,6 +310,7 @@ compile_topology_num_slices: -1 # Number of target slices, set to a positive int use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 quantization_calibration_method: "absmax" +qwix_module_path: ".*" # Regex of module paths to quantize with qwix; the default ".*" applies quantization to all modules. # Eval model on per eval_every steps. -1 means don't eval. eval_every: -1 diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index e93c9884d..1537ee290 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -243,18 +243,19 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): return wan_vae, vae_cache @classmethod - def get_basic_config(cls, dtype): + def get_basic_config(cls, dtype, config: HyperParameters): rules = [ qwix.QtRule( - module_path=".*", # Apply to all modules + module_path=config.qwix_module_path, # Apply to all modules weight_qtype=dtype, act_qtype=dtype, + op_names=("dot_general",), ) ] return rules @classmethod - def get_fp8_config(cls, quantization_calibration_method: str): + def get_fp8_config(cls, config: HyperParameters): """ fp8 config rules with per-tensor calibration. 
FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api): @@ -262,15 +263,16 @@ def get_fp8_config(cls, quantization_calibration_method: str): """ rules = [ qwix.QtRule( - module_path=".*", # Apply to all modules + module_path=config.qwix_module_path, # Apply to all modules weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=quantization_calibration_method, - act_calibration_method=quantization_calibration_method, - bwd_calibration_method=quantization_calibration_method, + weight_calibration_method=config.quantization_calibration_method, + act_calibration_method=config.quantization_calibration_method, + bwd_calibration_method=config.quantization_calibration_method, + op_names=("dot_general",), ) ] return rules @@ -281,14 +283,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: if not getattr(config, "use_qwix_quantization", False): return None - quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax") match config.quantization: case "int8": - return qwix.QtProvider(cls.get_basic_config(jnp.int8)) + return qwix.QtProvider(cls.get_basic_config(jnp.int8, config)) case "fp8": - return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn)) + return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config)) case "fp8_full": - return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method)) + return qwix.QtProvider(cls.get_fp8_config(config)) return None @classmethod From e12aec8ae729e573c1fa41bc743f8bc3a6db6a9d Mon Sep 17 00:00:00 2001 From: susanbao Date: Mon, 6 Oct 2025 22:04:12 +0000 Subject: [PATCH 2/3] fix commits --- .../pipelines/wan/wan_pipeline.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git 
a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 1537ee290..c78d8bae2 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -246,10 +246,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): def get_basic_config(cls, dtype, config: HyperParameters): rules = [ qwix.QtRule( - module_path=config.qwix_module_path, # Apply to all modules + module_path=config.qwix_module_path, weight_qtype=dtype, act_qtype=dtype, - op_names=("dot_general",), + op_names=("dot_general","einsum", "conv_general_dilated"), ) ] return rules @@ -263,7 +263,7 @@ def get_fp8_config(cls, config: HyperParameters): """ rules = [ qwix.QtRule( - module_path=config.qwix_module_path, # Apply to all modules + module_path=config.qwix_module_path, weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, @@ -272,7 +272,19 @@ def get_fp8_config(cls, config: HyperParameters): weight_calibration_method=config.quantization_calibration_method, act_calibration_method=config.quantization_calibration_method, bwd_calibration_method=config.quantization_calibration_method, - op_names=("dot_general",), + op_names=("dot_general","einsum"), + ), + qwix.QtRule( + module_path=config.qwix_module_path, + weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e4m3fn, + bwd_use_original_residuals=True, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method=config.quantization_calibration_method, + act_calibration_method=config.quantization_calibration_method, + bwd_calibration_method=config.quantization_calibration_method, + op_names=("conv_general_dilated",), ) ] return rules From 1054ee150a61527127b29a21548585556ae238a6 Mon Sep 17 00:00:00 2001 From: susanbao Date: Tue, 7 Oct 2025 06:34:59 +0000 Subject: [PATCH 3/3] fix test --- .../tests/wan_transformer_test.py 
| 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 26ea0f028..ae638e059 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -19,7 +19,7 @@ import jax.numpy as jnp import pytest import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, call from absl.testing import absltest from flax import nnx from jax.sharding import Mesh @@ -291,18 +291,20 @@ def test_get_qt_provider(self, mock_qt_rule): config_int8 = Mock(spec=HyperParameters) config_int8.use_qwix_quantization = True config_int8.quantization = "int8" + config_int8.qwix_module_path = ".*" provider_int8 = WanPipeline.get_qt_provider(config_int8) self.assertIsNotNone(provider_int8) - mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8) + mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8, op_names=("dot_general","einsum", "conv_general_dilated")) # Case 3: Quantization enabled, type 'fp8' mock_qt_rule.reset_mock() config_fp8 = Mock(spec=HyperParameters) config_fp8.use_qwix_quantization = True config_fp8.quantization = "fp8" + config_fp8.qwix_module_path = ".*" provider_fp8 = WanPipeline.get_qt_provider(config_fp8) self.assertIsNotNone(provider_fp8) - mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn) + mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated")) # Case 4: Quantization enabled, type 'fp8_full' mock_qt_rule.reset_mock() config_fp8_full = Mock(spec=HyperParameters) config_fp8_full.use_qwix_quantization = True config_fp8_full.quantization = "fp8_full" 
config_fp8_full.quantization_calibration_method = "absmax" + config_fp8_full.qwix_module_path = ".*" provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full) self.assertIsNotNone(provider_fp8_full) - mock_qt_rule.assert_called_once_with( + expected_calls = [ + call(module_path=".*", # Apply to all modules + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + bwd_use_original_residuals=True, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method=config_fp8_full.quantization_calibration_method, + act_calibration_method=config_fp8_full.quantization_calibration_method, + bwd_calibration_method=config_fp8_full.quantization_calibration_method, + op_names=("dot_general","einsum"), + ), + call( module_path=".*", # Apply to all modules weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, @@ -322,7 +337,10 @@ weight_calibration_method=config_fp8_full.quantization_calibration_method, act_calibration_method=config_fp8_full.quantization_calibration_method, bwd_calibration_method=config_fp8_full.quantization_calibration_method, - ) + op_names=("conv_general_dilated",), + ) + ] + mock_qt_rule.assert_has_calls(expected_calls, any_order=True) # Case 5: Invalid quantization type config_invalid = Mock(spec=HyperParameters) @@ -341,6 +359,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize mock_config = Mock(spec=HyperParameters) mock_config.use_qwix_quantization = True mock_config.quantization = "fp8_full" + mock_config.qwix_module_path = ".*" mock_config.per_device_batch_size = 1 mock_model = Mock(spec=WanModel)