Qwix #259 (Merged)

1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -310,6 +310,7 @@ compile_topology_num_slices: -1 # Number of target slices, set to a positive int
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 quantization_calibration_method: "absmax"
+qwix_module_path: ".*" # Regex selecting which modules qwix quantizes; ".*" matches all modules.
 
 # Eval model on per eval_every steps. -1 means don't eval.
 eval_every: -1
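Since `qwix_module_path` is a regex, it can also scope quantization to part of the model instead of all of it. A minimal sketch — the `transformer_blocks` pattern below is hypothetical, not a path taken from the WAN model:

```python
import jax.numpy as jnp
import qwix

# Hypothetical narrower scope: quantize only modules whose path mentions
# "transformer_blocks", leaving everything else (e.g. embeddings) untouched.
rule = qwix.QtRule(
    module_path=".*transformer_blocks.*",
    weight_qtype=jnp.int8,
    act_qtype=jnp.int8,
    op_names=("dot_general", "einsum"),
)
```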
35 changes: 24 additions & 11 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -243,34 +243,48 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     return wan_vae, vae_cache
 
   @classmethod
-  def get_basic_config(cls, dtype):
+  def get_basic_config(cls, dtype, config: HyperParameters):
     rules = [
         qwix.QtRule(
-            module_path=".*",  # Apply to all modules
+            module_path=config.qwix_module_path,
             weight_qtype=dtype,
             act_qtype=dtype,
+            op_names=("dot_general", "einsum", "conv_general_dilated"),
         )
     ]
     return rules
 
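For context, these rules take effect when handed to a `qwix.QtProvider` and applied to an nnx model. A minimal sketch with a toy module; the `qwix.quantize_model(model, provider, *sample_inputs)` entry point and signature are an assumption taken from the qwix README, so check the qwix docs before relying on it:

```python
import jax.numpy as jnp
import qwix
from flax import nnx

class TinyMLP(nnx.Module):
  """Toy stand-in for the WAN transformer, only here to exercise the rule."""

  def __init__(self, rngs: nnx.Rngs):
    self.fc = nnx.Linear(8, 8, rngs=rngs)

  def __call__(self, x):
    return self.fc(x)

rules = [
    qwix.QtRule(
        module_path=".*",
        weight_qtype=jnp.int8,
        act_qtype=jnp.int8,
        op_names=("dot_general", "einsum", "conv_general_dilated"),
    )
]
model = TinyMLP(nnx.Rngs(0))
x = jnp.ones((1, 8))
# Assumption: quantize_model takes (model, provider, *sample_inputs) and
# returns a quantized copy of the model.
qmodel = qwix.quantize_model(model, qwix.QtProvider(rules), x)
```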
   @classmethod
-  def get_fp8_config(cls, quantization_calibration_method: str):
+  def get_fp8_config(cls, config: HyperParameters):
     """
     fp8 config rules with per-tensor calibration.
     Per the Flax low-level API docs (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api),
     autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, even though that is the recommended practice.
     """
     rules = [
         qwix.QtRule(
-            module_path=".*",  # Apply to all modules
+            module_path=config.qwix_module_path,
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e5m2,
             bwd_use_original_residuals=True,
             disable_channelwise_axes=True,  # per_tensor calibration
-            weight_calibration_method=quantization_calibration_method,
-            act_calibration_method=quantization_calibration_method,
-            bwd_calibration_method=quantization_calibration_method,
+            weight_calibration_method=config.quantization_calibration_method,
+            act_calibration_method=config.quantization_calibration_method,
+            bwd_calibration_method=config.quantization_calibration_method,
+            op_names=("dot_general", "einsum"),
+        ),
+        qwix.QtRule(
+            module_path=config.qwix_module_path,
+            weight_qtype=jnp.float8_e4m3fn,  # conv_general_dilated requires the same dtypes
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e4m3fn,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=config.quantization_calibration_method,
+            act_calibration_method=config.quantization_calibration_method,
+            bwd_calibration_method=config.quantization_calibration_method,
+            op_names=("conv_general_dilated",),  # trailing comma: a one-element tuple, not a bare string
         )
     ]
     return rules
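The two rules differ only in `bwd_qtype`: E4M3 spends its bits on mantissa precision while E5M2 spends them on exponent range, which suits gradients, and `conv_general_dilated` (per the inline comment) needs one dtype throughout. The range difference is easy to verify:

```python
import jax.numpy as jnp

# E4M3 has more mantissa bits but saturates early; E5M2 trades precision for
# range, which is why the dot_general/einsum rule uses it for the backward pass.
print(jnp.finfo(jnp.float8_e4m3fn).max)  # 448.0
print(jnp.finfo(jnp.float8_e5m2).max)    # 57344.0
```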
@@ -281,14 +295,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
     if not getattr(config, "use_qwix_quantization", False):
       return None
 
-    quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
     match config.quantization:
       case "int8":
-        return qwix.QtProvider(cls.get_basic_config(jnp.int8))
+        return qwix.QtProvider(cls.get_basic_config(jnp.int8, config))
       case "fp8":
-        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
+        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config))
       case "fp8_full":
-        return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
+        return qwix.QtProvider(cls.get_fp8_config(config))
     return None
 
   @classmethod
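For illustration, how the dispatch above behaves end to end; `SimpleNamespace` is a stand-in for `HyperParameters`, and the import path follows this file's location in the repo:

```python
from types import SimpleNamespace

from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline

cfg = SimpleNamespace(
    use_qwix_quantization=True,
    quantization="fp8",
    qwix_module_path=".*",
    quantization_calibration_method="absmax",
)
provider = WanPipeline.get_qt_provider(cfg)  # QtProvider with e4m3 weight/act rules
assert provider is not None

# With quantization disabled, the getattr guard short-circuits to None.
assert WanPipeline.get_qt_provider(SimpleNamespace(use_qwix_quantization=False)) is None
```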
29 changes: 24 additions & 5 deletions src/maxdiffusion/tests/wan_transformer_test.py
@@ -19,7 +19,7 @@
 import jax.numpy as jnp
 import pytest
 import unittest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, patch, call
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
@@ -291,28 +291,43 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
+    config_int8.qwix_module_path = ".*"
     provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
-    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8)
+    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8, op_names=("dot_general", "einsum", "conv_general_dilated"))
 
     # Case 3: Quantization enabled, type 'fp8'
     mock_qt_rule.reset_mock()
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
+    config_fp8.qwix_module_path = ".*"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn)
+    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general", "einsum", "conv_general_dilated"))
 
     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
+    config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    mock_qt_rule.assert_called_once_with(
+    expected_calls = [
+        call(
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=config_fp8_full.quantization_calibration_method,
+            act_calibration_method=config_fp8_full.quantization_calibration_method,
+            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+            op_names=("dot_general", "einsum"),
+        ),
+        call(
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
@@ -322,7 +337,10 @@ def test_get_qt_provider(self, mock_qt_rule):
             weight_calibration_method=config_fp8_full.quantization_calibration_method,
             act_calibration_method=config_fp8_full.quantization_calibration_method,
             bwd_calibration_method=config_fp8_full.quantization_calibration_method,
-        )
+            op_names=("conv_general_dilated",),
+        ),
+    ]
+    mock_qt_rule.assert_has_calls(expected_calls, any_order=True)
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
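Because `get_fp8_config` now builds two `QtRule`s, `assert_called_once_with` no longer applies; `assert_has_calls` with `call` objects checks both invocations instead. A self-contained sketch of that mock pattern:

```python
from unittest.mock import Mock, call

rule_factory = Mock()
rule_factory(op_names=("dot_general", "einsum"))
rule_factory(op_names=("conv_general_dilated",))

# any_order=True matches the expected calls as a set rather than a strict sequence.
rule_factory.assert_has_calls(
    [call(op_names=("conv_general_dilated",)), call(op_names=("dot_general", "einsum"))],
    any_order=True,
)
```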
@@ -341,6 +359,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config = Mock(spec=HyperParameters)
     mock_config.use_qwix_quantization = True
     mock_config.quantization = "fp8_full"
+    mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
 
     mock_model = Mock(spec=WanModel)