
Commit 8fd07f4

Solve all review problems

1 parent efe13bf commit 8fd07f4

5 files changed

Lines changed: 111 additions & 5 deletions


requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -35,4 +35,5 @@ sentencepiece
 aqtp
 imageio==2.37.0
 imageio-ffmpeg==0.6.0
-hf_transfer>=0.1.9
+hf_transfer>=0.1.9
+qwix@git+https://github.com/google/qwix.git

requirements_with_jax_ai_image.txt

Lines changed: 2 additions & 1 deletion
@@ -35,4 +35,5 @@ sentencepiece
 aqtp
 imageio==2.37.0
 imageio-ffmpeg==0.6.0
-hf_transfer>=0.1.9
+hf_transfer>=0.1.9
+qwix@git+https://github.com/google/qwix.git

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 3 deletions
@@ -33,7 +33,7 @@
 from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
 from transformers import AutoTokenizer, UMT5EncoderModel
 from maxdiffusion.utils.import_utils import is_ftfy_available
-from ...maxdiffusion_utils import get_dummy_wan_inputs
+from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs
 import html
 import re
 import torch
@@ -240,7 +240,10 @@ def get_basic_config(cls, dtype):
 
   @classmethod
   def get_fp8_config(cls, quantization_calibration_method: str):
-    """ fp8 config rules with per-tensor calibration.
+    """
+    fp8 config rules with per-tensor calibration.
+    FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api):
+    The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice.
     """
     rules = [
         qwix.QtRule(
@@ -344,7 +347,6 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
       pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh)
     return pipeline
 
-
   def _get_t5_prompt_embeds(
       self,
       prompt: Union[str, List[str]] = None,
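
The docstring added above references the FP8 training recipe: E4M3 for weights and activations on the forward pass, E5M2 for gradients on the backward pass. As a point of reference, a minimal sketch of a qwix rule expressing that recipe follows. The weight_qtype and bwd_qtype kwargs are the ones the new tests assert on; module_path, act_qtype, and the QtProvider wrapper are assumptions about the qwix API, not taken from this commit.

import jax.numpy as jnp
import qwix

# Sketch only: E4M3 forward, E5M2 backward.
fp8_rules = [
  qwix.QtRule(
    module_path=".*",                # assumed: match every module
    weight_qtype=jnp.float8_e4m3fn,  # E4M3 weights (kwarg asserted in the tests below)
    act_qtype=jnp.float8_e4m3fn,     # assumed kwarg for activations
    bwd_qtype=jnp.float8_e5m2,       # E5M2 gradients (kwarg asserted in the tests below)
  )
]
provider = qwix.QtProvider(fp8_rules)  # assumed wrapper returned by get_qt_provider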

src/maxdiffusion/pyconfig.py

Lines changed: 8 additions & 0 deletions
@@ -136,6 +136,14 @@ def wan_init(raw_keys):
       )
     else:
       raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
+    if "use_qwix_quantization" not in raw_keys:
+      raise ValueError(f"use_qwix_quantization is not set.")
+    elif raw_keys["use_qwix_quantization"]:
+      if "quantization" not in raw_keys:
+        raise ValueError(f"Quantization type is not set when use_qwix_quantization is enabled.")
+      elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
+        raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
+
 
   @staticmethod
   def calculate_global_batch_sizes(per_device_batch_size):
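
The new check in wan_init requires use_qwix_quantization to be present and, when it is enabled, a quantization type from the supported set. A few example raw_keys dicts illustrating the branches (a sketch; in practice these values come from the parsed YAML config):

ok_fp8      = {"use_qwix_quantization": True, "quantization": "fp8"}   # accepted
ok_disabled = {"use_qwix_quantization": False}                         # accepted; type not checked
missing     = {"use_qwix_quantization": True}                          # ValueError: type not set
unsupported = {"use_qwix_quantization": True, "quantization": "int4"}  # ValueError: not int8/fp8/fp8_full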

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 94 additions & 0 deletions
@@ -19,6 +19,7 @@
 import jax.numpy as jnp
 import pytest
 import unittest
+from unittest.mock import Mock, patch, MagicMock
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
@@ -34,6 +35,9 @@
 from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
 from ..models.normalization_flax import FP32LayerNorm
 from ..models.attention_flax import FlaxWanAttention
+from maxdiffusion.pyconfig import HyperParameters
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+
 
 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
 
@@ -272,6 +276,96 @@ def test_wan_model(self):
         hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
     )
     assert dummy_output.shape == hidden_states_shape
+
+  def test_get_qt_provider(self):
+    """
+    Tests the provider logic for all config branches.
+    """
+    # Case 1: Quantization disabled
+    config_disabled = Mock(spec=HyperParameters)
+    config_disabled.use_qwix_quantization = False
+    self.assertIsNone(WanPipeline.get_qt_provider(config_disabled))
+
+    # Case 2: Quantization enabled, type 'int8'
+    config_int8 = Mock(spec=HyperParameters)
+    config_int8.use_qwix_quantization = True
+    config_int8.quantization = "int8"
+    provider_int8 = WanPipeline.get_qt_provider(config_int8)
+    self.assertIsNotNone(provider_int8)
+    self.assertEqual(provider_int8.rules[0].kwargs['weight_qtype'], jnp.int8)
+
+    # Case 3: Quantization enabled, type 'fp8'
+    config_fp8 = Mock(spec=HyperParameters)
+    config_fp8.use_qwix_quantization = True
+    config_fp8.quantization = "fp8"
+    provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
+    self.assertIsNotNone(provider_fp8)
+    self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
+
+    # Case 4: Quantization enabled, type 'fp8_full'
+    config_fp8_full = Mock(spec=HyperParameters)
+    config_fp8_full.use_qwix_quantization = True
+    config_fp8_full.quantization = "fp8_full"
+    config_fp8_full.quantization_calibration_method = "absmax"
+    provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
+    self.assertIsNotNone(provider_fp8_full)
+    self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)
+
+    # Case 5: Invalid quantization type
+    config_invalid = Mock(spec=HyperParameters)
+    config_invalid.use_qwix_quantization = True
+    config_invalid.quantization = "invalid_type"
+    self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
+
+  # To test quantize_transformer, we patch its external dependencies
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
+  def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
+    """
+    Tests that quantize_transformer calls qwix when quantization is enabled.
+    """
+    # Setup Mocks
+    mock_config = Mock(spec=HyperParameters)
+    mock_config.use_qwix_quantization = True
+    mock_config.quantization = "fp8_full"
+    mock_config.per_device_batch_size = 1
+
+    mock_model = Mock(spec=WanModel)
+    mock_pipeline = Mock()
+    mock_mesh = Mock()
+
+    # Mock the return values of dependencies
+    mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
+    mock_quantized_model_obj = Mock(spec=WanModel)
+    mock_quantize_model.return_value = mock_quantized_model_obj
+
+    # Call the method under test
+    result = WanPipeline.quantize_transformer(mock_config, mock_model, mock_pipeline, mock_mesh)
+
+    # Assertions
+    mock_get_dummy_inputs.assert_called_once()
+    mock_quantize_model.assert_called_once()
+    # Check that the model returned is the new quantized model
+    self.assertIs(result, mock_quantized_model_obj)
+
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  def test_quantize_transformer_disabled(self, mock_quantize_model):
+    """
+    Tests that quantize_transformer is skipped when quantization is disabled.
+    """
+    # Setup Mocks
+    mock_config = Mock(spec=HyperParameters)
+    mock_config.use_qwix_quantization = False  # Main condition for this test
+
+    mock_model = Mock(spec=WanModel)
+
+    # Call the method under test
+    result = WanPipeline.quantize_transformer(mock_config, mock_model, Mock(), Mock())
+
+    # Assertions
+    mock_quantize_model.assert_not_called()
+    # Check that the model returned is the original model instance
+    self.assertIs(result, mock_model)
 
 
 if __name__ == "__main__":
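
Together, the two quantize_transformer tests pin down a simple contract: return the model untouched when quantization is disabled, otherwise build dummy inputs and hand the model to qwix. A hedged reconstruction of that flow, assuming the qwix.quantize_model(model, provider, *inputs) calling convention; the committed method may differ in its details:

import qwix
from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs

def quantize_transformer_sketch(config, model, pipeline, mesh, qt_provider):
  # Mirrors only what the tests assert, not the committed implementation.
  if not config.use_qwix_quantization:
    return model  # disabled: the original model instance is returned
  dummy_inputs = get_dummy_wan_inputs(config, pipeline, mesh)  # signature assumed
  return qwix.quantize_model(model, qt_provider, *dummy_inputs)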
