Skip to content

Commit efe13bf

Browse files
committed
qwix quantize WAN transformer
1 parent fc46fcc commit efe13bf

2 files changed

Lines changed: 72 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,4 +287,7 @@ quantization: ''
287287
# Shard the range finding operation for quantization. By default this is set to number of slices.
288288
quantization_local_shard_count: -1
289289
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
290+
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
291+
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
292+
quantization_calibration_method: "absmax"
290293

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
3434
from transformers import AutoTokenizer, UMT5EncoderModel
3535
from maxdiffusion.utils.import_utils import is_ftfy_available
36+
from ...maxdiffusion_utils import get_dummy_wan_inputs
3637
import html
3738
import re
3839
import torch
40+
import qwix
3941

4042

4143
def basic_clean(text):
@@ -225,6 +227,68 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
225227
vae_cache = AutoencoderKLWanCache(wan_vae)
226228
return wan_vae, vae_cache
227229

230+
@classmethod
def get_basic_config(cls, dtype):
  """Return a one-rule Qwix config quantizing every module's weights and activations to *dtype*."""
  catch_all_rule = qwix.QtRule(
      module_path='.*',  # regex matching every module in the model
      weight_qtype=dtype,
      act_qtype=dtype,
  )
  return [catch_all_rule]
240+
241+
@classmethod
def get_fp8_config(cls, quantization_calibration_method: str):
  """Return fp8 Qwix rules using per-tensor calibration.

  Forward pass uses e4m3 for weights/activations; the backward pass uses the
  wider-range e5m2 with original (unquantized) residuals.
  """
  fp8_rule = qwix.QtRule(
      module_path='.*',  # regex matching every module in the model
      weight_qtype=jnp.float8_e4m3fn,
      act_qtype=jnp.float8_e4m3fn,
      bwd_qtype=jnp.float8_e5m2,
      bwd_use_original_residuals=True,
      disable_channelwise_axes=True,  # per-tensor calibration
      weight_calibration_method=quantization_calibration_method,
      act_calibration_method=quantization_calibration_method,
      bwd_calibration_method=quantization_calibration_method,
  )
  return [fp8_rule]
259+
260+
@classmethod
def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
  """Build a QtProvider from the config, or None when qwix quantization is disabled.

  Recognized `config.quantization` values: "int8", "fp8" (basic catch-all rules)
  and "fp8_full" (per-tensor calibrated fp8); anything else yields None.
  """
  if not getattr(config, "use_qwix_quantization", False):
    return None

  calibration_method = getattr(config, "quantization_calibration_method", "absmax")
  mode = config.quantization
  if mode == "int8":
    return qwix.QtProvider(cls.get_basic_config(jnp.int8))
  if mode == "fp8":
    return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
  if mode == "fp8_full":
    return qwix.QtProvider(cls.get_fp8_config(calibration_method))
  # Unknown/empty quantization mode: leave the model unquantized.
  return None
275+
276+
@classmethod
def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh):
  """Quantize the WAN transformer with Qwix; return the model unchanged when disabled.

  Builds dummy (latents, timesteps, prompt_embeds) inputs so qwix can trace the
  model, then runs `qwix.quantize_model` under the device mesh.
  """
  provider = cls.get_qt_provider(config)
  if not provider:
    return model
  max_logging.log("Quantizing transformer with Qwix.")

  batch_size = int(config.per_device_batch_size * jax.local_device_count())
  latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
  with mesh:
    quantized_transformer = qwix.quantize_model(model, provider, latents, timesteps, prompt_embeds)
  max_logging.log("Qwix Quantization complete.")
  return quantized_transformer
291+
228292
@classmethod
229293
def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
230294
with mesh:
@@ -264,7 +328,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
264328
with mesh:
265329
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
266330

267-
return WanPipeline(
331+
pipeline = WanPipeline(
268332
tokenizer=tokenizer,
269333
text_encoder=text_encoder,
270334
transformer=transformer,
@@ -277,6 +341,10 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
277341
config=config,
278342
)
279343

344+
pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh)
345+
return pipeline
346+
347+
280348
def _get_t5_prompt_embeds(
281349
self,
282350
prompt: Union[str, List[str]] = None,

0 commit comments

Comments
 (0)