Commit de60c6c

qwix quantize WAN transformer (#226)
* qwix quantize WAN transformer
* Solve all review problems
1 parent ecd3514 commit de60c6c

6 files changed: 180 additions & 3 deletions

requirements.txt

Lines changed: 2 additions & 1 deletion

```diff
@@ -35,4 +35,5 @@ sentencepiece
 aqtp
 imageio==2.37.0
 imageio-ffmpeg==0.6.0
-hf_transfer>=0.1.9
+hf_transfer>=0.1.9
+qwix@git+https://github.com/google/qwix.git
```

requirements_with_jax_ai_image.txt

Lines changed: 2 additions & 1 deletion

```diff
@@ -35,4 +35,5 @@ sentencepiece
 aqtp
 imageio==2.37.0
 imageio-ffmpeg==0.6.0
-hf_transfer>=0.1.9
+hf_transfer>=0.1.9
+qwix@git+https://github.com/google/qwix.git
```
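Both requirement files pin qwix to the upstream git source rather than a PyPI release. For a local environment, the same dependency can be pulled in directly with pip's VCS syntax (illustrative command, not part of the commit):

```
pip install "qwix @ git+https://github.com/google/qwix.git"
```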

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -287,4 +287,7 @@ quantization: ''
 # Shard the range finding operation for quantization. By default this is set to number of slices.
 quantization_local_shard_count: -1
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
+use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the WAN transformer will be quantized using qwix.
+# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
+quantization_calibration_method: "absmax"
```
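With these defaults in place, enabling quantization is a config-level switch. A minimal sketch of a run config overriding them (the three keys come from this commit; the chosen values are illustrative, with `quantization` limited to the types the pipeline accepts):

```yaml
use_qwix_quantization: True
quantization: "fp8_full"                   # one of: int8, fp8, fp8_full
quantization_calibration_method: "absmax"  # see qwix/qconfig.py for other methods
```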

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 71 additions & 1 deletion

```diff
@@ -33,9 +33,11 @@
 from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
 from transformers import AutoTokenizer, UMT5EncoderModel
 from maxdiffusion.utils.import_utils import is_ftfy_available
+from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs
 import html
 import re
 import torch
+import qwix


 def basic_clean(text):
@@ -225,6 +227,71 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     vae_cache = AutoencoderKLWanCache(wan_vae)
     return wan_vae, vae_cache

+  @classmethod
+  def get_basic_config(cls, dtype):
+    rules = [
+        qwix.QtRule(
+            module_path=".*",  # Apply to all modules
+            weight_qtype=dtype,
+            act_qtype=dtype,
+        )
+    ]
+    return rules
+
+  @classmethod
+  def get_fp8_config(cls, quantization_calibration_method: str):
+    """fp8 config rules with per-tensor calibration.
+
+    Per the Flax low-level API docs
+    (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api),
+    autodiff does not automatically use E5M2 for gradients and E4M3 for
+    activations/weights during training, which is the recommended practice,
+    so the rule below sets those qtypes explicitly.
+    """
+    rules = [
+        qwix.QtRule(
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per-tensor calibration
+            weight_calibration_method=quantization_calibration_method,
+            act_calibration_method=quantization_calibration_method,
+            bwd_calibration_method=quantization_calibration_method,
+        )
+    ]
+    return rules
+
+  @classmethod
+  def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
+    """Get quantization rules based on the config."""
+    if not getattr(config, "use_qwix_quantization", False):
+      return None
+
+    quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
+    match config.quantization:
+      case "int8":
+        return qwix.QtProvider(cls.get_basic_config(jnp.int8))
+      case "fp8":
+        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
+      case "fp8_full":
+        return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
+    return None
+
+  @classmethod
+  def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh):
+    """Quantizes the transformer model."""
+    q_rules = cls.get_qt_provider(config)
+    if not q_rules:
+      return model
+    max_logging.log("Quantizing transformer with Qwix.")
+
+    batch_size = int(config.per_device_batch_size * jax.local_device_count())
+    latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
+    model_inputs = (latents, timesteps, prompt_embeds)
+    with mesh:
+      quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
+    max_logging.log("Qwix quantization complete.")
+    return quantized_model
+
   @classmethod
   def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
     with mesh:
@@ -264,7 +331,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
     with mesh:
       wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)

-    return WanPipeline(
+    pipeline = WanPipeline(
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         transformer=transformer,
@@ -277,6 +344,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
         config=config,
     )

+    pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh)
+    return pipeline
+
   def _get_t5_prompt_embeds(
       self,
       prompt: Union[str, List[str]] = None,
```
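For orientation, here is a self-contained sketch of the same qwix flow on a toy flax.nnx module, outside the pipeline. Only `qwix.QtRule`, `qwix.QtProvider`, and `qwix.quantize_model` come from this commit; `ToyBlock`, the shapes, and the sample input are illustrative assumptions:

```python
import jax.numpy as jnp
from flax import nnx
import qwix


class ToyBlock(nnx.Module):
  """A stand-in for the WAN transformer; illustrative only."""

  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(16, 16, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)


# Same rule shape as get_basic_config(jnp.int8) above: one rule for all modules.
rules = [qwix.QtRule(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8)]
provider = qwix.QtProvider(rules)

model = ToyBlock(rngs=nnx.Rngs(0))
sample_input = jnp.ones((1, 16))
# quantize_model traces the model with sample inputs and swaps in quantized ops,
# mirroring the call inside quantize_transformer.
quantized = qwix.quantize_model(model, provider, sample_input)
```

The pipeline does the same thing at scale: it builds dummy latents, timesteps, and prompt embeddings via `get_dummy_wan_inputs`, then traces the transformer under the device mesh so the quantized model inherits the original sharding.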

src/maxdiffusion/pyconfig.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -136,6 +136,14 @@ def wan_init(raw_keys):
       )
     else:
       raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
+    if "use_qwix_quantization" not in raw_keys:
+      raise ValueError("use_qwix_quantization is not set.")
+    elif raw_keys["use_qwix_quantization"]:
+      if "quantization" not in raw_keys:
+        raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
+      elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
+        raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
+

   @staticmethod
   def calculate_global_batch_sizes(per_device_batch_size):
```
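The accepted combinations are easiest to see pulled out of `wan_init`. A minimal standalone paraphrase of the check (the helper name `validate_qwix_config` is hypothetical, not part of the commit):

```python
def validate_qwix_config(raw_keys: dict) -> None:
  """Mirrors the validation added to wan_init above (paraphrase, not the committed code)."""
  if "use_qwix_quantization" not in raw_keys:
    raise ValueError("use_qwix_quantization is not set.")
  if raw_keys["use_qwix_quantization"]:
    if "quantization" not in raw_keys:
      raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
    if raw_keys["quantization"] not in ("int8", "fp8", "fp8_full"):
      raise ValueError(f"Quantization type is not supported: {raw_keys['quantization']}")


validate_qwix_config({"use_qwix_quantization": True, "quantization": "fp8_full"})  # passes
validate_qwix_config({"use_qwix_quantization": False})                             # passes
```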

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 94 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@
 import jax.numpy as jnp
 import pytest
 import unittest
+from unittest.mock import Mock, patch, MagicMock
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
@@ -34,6 +35,9 @@
 from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
 from ..models.normalization_flax import FP32LayerNorm
 from ..models.attention_flax import FlaxWanAttention
+from maxdiffusion.pyconfig import HyperParameters
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+

 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

@@ -272,6 +276,96 @@ def test_wan_model(self):
         hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
     )
     assert dummy_output.shape == hidden_states_shape
+
+  def test_get_qt_provider(self):
+    """Tests the provider logic for all config branches."""
+    # Case 1: Quantization disabled
+    config_disabled = Mock(spec=HyperParameters)
+    config_disabled.use_qwix_quantization = False
+    self.assertIsNone(WanPipeline.get_qt_provider(config_disabled))
+
+    # Case 2: Quantization enabled, type 'int8'
+    config_int8 = Mock(spec=HyperParameters)
+    config_int8.use_qwix_quantization = True
+    config_int8.quantization = "int8"
+    provider_int8 = WanPipeline.get_qt_provider(config_int8)
+    self.assertIsNotNone(provider_int8)
+    self.assertEqual(provider_int8.rules[0].kwargs["weight_qtype"], jnp.int8)
+
+    # Case 3: Quantization enabled, type 'fp8'
+    config_fp8 = Mock(spec=HyperParameters)
+    config_fp8.use_qwix_quantization = True
+    config_fp8.quantization = "fp8"
+    provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
+    self.assertIsNotNone(provider_fp8)
+    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
+
+    # Case 4: Quantization enabled, type 'fp8_full'
+    config_fp8_full = Mock(spec=HyperParameters)
+    config_fp8_full.use_qwix_quantization = True
+    config_fp8_full.quantization = "fp8_full"
+    config_fp8_full.quantization_calibration_method = "absmax"
+    provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
+    self.assertIsNotNone(provider_fp8_full)
+    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
+
+    # Case 5: Invalid quantization type
+    config_invalid = Mock(spec=HyperParameters)
+    config_invalid.use_qwix_quantization = True
+    config_invalid.quantization = "invalid_type"
+    self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
+
+  # To test quantize_transformer, we patch its external dependencies
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
+  def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
+    """Tests that quantize_transformer calls qwix when quantization is enabled."""
+    # Set up mocks
+    mock_config = Mock(spec=HyperParameters)
+    mock_config.use_qwix_quantization = True
+    mock_config.quantization = "fp8_full"
+    mock_config.per_device_batch_size = 1
+
+    mock_model = Mock(spec=WanModel)
+    mock_pipeline = Mock()
+    # MagicMock supports the context-manager protocol needed for `with mesh:`
+    mock_mesh = MagicMock()
+
+    # Mock the return values of dependencies
+    mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
+    mock_quantized_model_obj = Mock(spec=WanModel)
+    mock_quantize_model.return_value = mock_quantized_model_obj
+
+    # Call the method under test
+    result = WanPipeline.quantize_transformer(mock_config, mock_model, mock_pipeline, mock_mesh)
+
+    # Assertions
+    mock_get_dummy_inputs.assert_called_once()
+    mock_quantize_model.assert_called_once()
+    # Check that the model returned is the new quantized model
+    self.assertIs(result, mock_quantized_model_obj)
+
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  def test_quantize_transformer_disabled(self, mock_quantize_model):
+    """Tests that quantize_transformer is skipped when quantization is disabled."""
+    # Set up mocks
+    mock_config = Mock(spec=HyperParameters)
+    mock_config.use_qwix_quantization = False  # Main condition for this test
+
+    mock_model = Mock(spec=WanModel)
+
+    # Call the method under test
+    result = WanPipeline.quantize_transformer(mock_config, mock_model, Mock(), Mock())
+
+    # Assertions
+    mock_quantize_model.assert_not_called()
+    # Check that the model returned is the original model instance
+    self.assertIs(result, mock_model)


 if __name__ == "__main__":
```
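These tests patch out `qwix.quantize_model` and `get_dummy_wan_inputs`, so they exercise only the dispatch logic and run without accelerators. One way to run just the new tests (assuming a standard pytest setup; the `-k` expression simply matches the test names above):

```
python -m pytest src/maxdiffusion/tests/wan_transformer_test.py -k "qt_provider or quantize_transformer"
```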
