Skip to content

Commit 1e64ebd

Browse files
committed
qwix
1 parent 3a9d12b commit 1e64ebd

2 files changed

Lines changed: 14 additions & 12 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ compile_topology_num_slices: -1 # Number of target slices, set to a positive int
310310
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
311311
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
312312
quantization_calibration_method: "absmax"
313+
qwix_module_path: ".*"
313314

314315
# Eval model on per eval_every steps. -1 means don't eval.
315316
eval_every: -1

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -243,34 +243,36 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
243243
return wan_vae, vae_cache
244244

245245
@classmethod
246-
def get_basic_config(cls, dtype):
246+
def get_basic_config(cls, dtype, config: HyperParameters):
247247
rules = [
248248
qwix.QtRule(
249-
module_path=".*", # Apply to all modules
249+
module_path=config.qwix_module_path,  # Configurable module-path regex (default ".*" applies to all modules)
250250
weight_qtype=dtype,
251251
act_qtype=dtype,
252+
op_names=("dot_general",),
252253
)
253254
]
254255
return rules
255256

256257
@classmethod
257-
def get_fp8_config(cls, quantization_calibration_method: str):
258+
def get_fp8_config(cls, config: HyperParameters):
258259
"""
259260
fp8 config rules with per-tensor calibration.
260261
FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api):
261262
The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice.
262263
"""
263264
rules = [
264265
qwix.QtRule(
265-
module_path=".*", # Apply to all modules
266+
module_path=config.qwix_module_path,  # Configurable module-path regex (default ".*" applies to all modules)
266267
weight_qtype=jnp.float8_e4m3fn,
267268
act_qtype=jnp.float8_e4m3fn,
268-
bwd_qtype=jnp.float8_e4m3fn,
269+
bwd_qtype=jnp.float8_e5m2,
269270
bwd_use_original_residuals=True,
270271
disable_channelwise_axes=True, # per_tensor calibration
271-
weight_calibration_method=quantization_calibration_method,
272-
act_calibration_method=quantization_calibration_method,
273-
bwd_calibration_method=quantization_calibration_method,
272+
weight_calibration_method=config.quantization_calibration_method,
273+
act_calibration_method=config.quantization_calibration_method,
274+
bwd_calibration_method=config.quantization_calibration_method,
275+
op_names=("dot_general",),
274276
)
275277
]
276278
return rules
@@ -281,14 +283,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
281283
if not getattr(config, "use_qwix_quantization", False):
282284
return None
283285

284-
quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
285286
match config.quantization:
286287
case "int8":
287-
return qwix.QtProvider(cls.get_basic_config(jnp.int8))
288+
return qwix.QtProvider(cls.get_basic_config(jnp.int8, config))
288289
case "fp8":
289-
return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
290+
return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config))
290291
case "fp8_full":
291-
return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
292+
return qwix.QtProvider(cls.get_fp8_config(config))
292293
return None
293294

294295
@classmethod

0 commit comments

Comments (0)