3 changes: 2 additions & 1 deletion requirements.txt
@@ -35,4 +35,5 @@ sentencepiece
aqtp
imageio==2.37.0
imageio-ffmpeg==0.6.0
hf_transfer>=0.1.9
hf_transfer>=0.1.9
qwix@git+https://github.com/google/qwix.git
3 changes: 2 additions & 1 deletion requirements_with_jax_ai_image.txt
@@ -35,4 +35,5 @@ sentencepiece
aqtp
imageio==2.37.0
imageio-ffmpeg==0.6.0
hf_transfer>=0.1.9
hf_transfer>=0.1.9
qwix@git+https://github.com/google/qwix.git
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -287,4 +287,7 @@ quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the WAN transformer will be quantized with qwix.
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
quantization_calibration_method: "absmax"
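
For example, enabling full FP8 quantization of the WAN transformer would combine these keys roughly as follows (illustrative values; "int8" and "fp8" are the other types accepted by the pyconfig check later in this PR):

use_qwix_quantization: True
quantization: "fp8_full"
quantization_calibration_method: "absmax"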

72 changes: 71 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -33,9 +33,11 @@
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
from transformers import AutoTokenizer, UMT5EncoderModel
from maxdiffusion.utils.import_utils import is_ftfy_available
from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs
import html
import re
import torch
import qwix


def basic_clean(text):
@@ -225,6 +227,71 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
vae_cache = AutoencoderKLWanCache(wan_vae)
return wan_vae, vae_cache

@classmethod
def get_basic_config(cls, dtype):
rules = [
qwix.QtRule(
module_path='.*', # Apply to all modules
weight_qtype=dtype,
act_qtype=dtype,
)
]
return rules

@classmethod
def get_fp8_config(cls, quantization_calibration_method: str):
"""
FP8 config rules with per-tensor calibration.
Note (Flax low-level FP8 API: https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api):
autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the
recommended practice, so the backward dtype is set explicitly below.
"""
rules = [
qwix.QtRule(
module_path='.*', # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=quantization_calibration_method,
act_calibration_method=quantization_calibration_method,
bwd_calibration_method=quantization_calibration_method,
)
]
return rules
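# Illustration of the dtype split above (a general FP8 property, not qwix-specific):
# E4M3 trades range for precision, so it is used for weights/activations, while E5M2
# trades precision for range, which gradients need. With plain JAX:
#   import jax.numpy as jnp
#   jnp.finfo(jnp.float8_e4m3fn).max  # 448.0
#   jnp.finfo(jnp.float8_e5m2).max    # 57344.0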

@classmethod
def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
"""Get quantization rules based on the config."""
if not getattr(config, "use_qwix_quantization", False):
return None

quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
match config.quantization:
case "int8":
return qwix.QtProvider(cls.get_basic_config(jnp.int8))
case "fp8":
return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
case "fp8_full":
return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
return None

@classmethod
def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh):
"""Quantizes the transformer model."""
q_rules = cls.get_qt_provider(config)
if not q_rules:
return model
max_logging.log("Quantizing transformer with Qwix.")

batch_size = int(config.per_device_batch_size * jax.local_device_count())
latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
model_inputs = (latents, timesteps, prompt_embeds)
with mesh:
quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
max_logging.log("Qwix Quantization complete.")
return quantized_model
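# Minimal sketch of the same qwix API on a toy module (illustrative only; assumes
# qwix.quantize_model accepts any nnx.Module plus sample inputs, as it is called
# above with WanModel):
#   provider = qwix.QtProvider(cls.get_basic_config(jnp.int8))
#   toy = nnx.Linear(8, 8, rngs=nnx.Rngs(0))
#   quantized_toy = qwix.quantize_model(toy, provider, jnp.ones((1, 8)))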

@classmethod
def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
with mesh:
@@ -264,7 +331,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
with mesh:
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)

return WanPipeline(
pipeline = WanPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
@@ -277,6 +344,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
config=config,
)

pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh)
return pipeline

def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
8 changes: 8 additions & 0 deletions src/maxdiffusion/pyconfig.py
@@ -136,6 +136,14 @@ def wan_init(raw_keys):
)
else:
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
if "use_qwix_quantization" not in raw_keys:
raise ValueError(f"use_qwix_quantization is not set.")
elif raw_keys["use_qwix_quantization"]:
if "quantization" not in raw_keys:
raise ValueError(f"Quantization type is not set when use_qwix_quantization is enabled.")
elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")


@staticmethod
def calculate_global_batch_sizes(per_device_batch_size):
94 changes: 94 additions & 0 deletions src/maxdiffusion/tests/wan_transformer_test.py
@@ -19,6 +19,7 @@
import jax.numpy as jnp
import pytest
import unittest
from unittest.mock import Mock, patch, MagicMock
from absl.testing import absltest
from flax import nnx
from jax.sharding import Mesh
@@ -34,6 +35,9 @@
from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
from ..models.normalization_flax import FP32LayerNorm
from ..models.attention_flax import FlaxWanAttention
from maxdiffusion.pyconfig import HyperParameters
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline


IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

@@ -272,6 +276,96 @@ def test_wan_model(self):
hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
)
assert dummy_output.shape == hidden_states_shape

def test_get_qt_provider(self):
"""
Tests the provider logic for all config branches.
"""
# Case 1: Quantization disabled
config_disabled = Mock(spec=HyperParameters)
config_disabled.use_qwix_quantization = False
self.assertIsNone(WanPipeline.get_qt_provider(config_disabled))

# Case 2: Quantization enabled, type 'int8'
config_int8 = Mock(spec=HyperParameters)
config_int8.use_qwix_quantization = True
config_int8.quantization = "int8"
provider_int8 = WanPipeline.get_qt_provider(config_int8)
self.assertIsNotNone(provider_int8)
self.assertEqual(provider_int8.rules[0].kwargs['weight_qtype'], jnp.int8)

# Case 3: Quantization enabled, type 'fp8'
config_fp8 = Mock(spec=HyperParameters)
config_fp8.use_qwix_quantization = True
config_fp8.quantization = "fp8"
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
self.assertIsNotNone(provider_fp8)
self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)

# Case 4: Quantization enabled, type 'fp8_full'
config_fp8_full = Mock(spec=HyperParameters)
config_fp8_full.use_qwix_quantization = True
config_fp8_full.quantization = "fp8_full"
config_fp8_full.quantization_calibration_method = "absmax"
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
self.assertIsNotNone(provider_fp8_full)
self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)

# Case 5: Invalid quantization type
config_invalid = Mock(spec=HyperParameters)
config_invalid.use_qwix_quantization = True
config_invalid.quantization = "invalid_type"
self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))

# To test quantize_transformer, we patch its external dependencies
@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
@patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
"""
Tests that quantize_transformer calls qwix when quantization is enabled.
"""
# Setup Mocks
mock_config = Mock(spec=HyperParameters)
mock_config.use_qwix_quantization = True
mock_config.quantization = "fp8_full"
mock_config.per_device_batch_size = 1

mock_model = Mock(spec=WanModel)
mock_pipeline = Mock()
mock_mesh = Mock()

# Mock the return values of dependencies
mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
mock_quantized_model_obj = Mock(spec=WanModel)
mock_quantize_model.return_value = mock_quantized_model_obj

# Call the method under test
result = WanPipeline.quantize_transformer(mock_config, mock_model, mock_pipeline, mock_mesh)

# Assertions
mock_get_dummy_inputs.assert_called_once()
mock_quantize_model.assert_called_once()
# Check that the model returned is the new quantized model
self.assertIs(result, mock_quantized_model_obj)

@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
def test_quantize_transformer_disabled(self, mock_quantize_model):
"""
Tests that quantize_transformer is skipped when quantization is disabled.
"""
# Setup Mocks
mock_config = Mock(spec=HyperParameters)
mock_config.use_qwix_quantization = False # Main condition for this test

mock_model = Mock(spec=WanModel)

# Call the method under test
result = WanPipeline.quantize_transformer(mock_config, mock_model, Mock(), Mock())

# Assertions
mock_quantize_model.assert_not_called()
# Check that the model returned is the original model instance
self.assertIs(result, mock_model)


if __name__ == "__main__":