Qwix #259 (Merged)

1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -310,6 +310,7 @@ compile_topology_num_slices: -1 # Number of target slices, set to a positive int
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the WAN transformer will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 quantization_calibration_method: "absmax"
+qwix_module_path: ".*" # Regex selecting which modules qwix quantizes; ".*" matches all modules.

 # Evaluate the model every eval_every steps. -1 means don't eval.
 eval_every: -1
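
For illustration, a minimal sketch of what the new qwix_module_path knob controls, using only the qwix.QtRule and qwix.QtProvider classes that appear in the pipeline changes below; the ".*attn.*" pattern is a hypothetical value, not one used in this PR:

import jax.numpy as jnp
import qwix

# Hypothetical: quantize only modules whose path matches ".*attn.*",
# rather than the default ".*" (every module).
rules = [
    qwix.QtRule(
        module_path=".*attn.*",  # regex matched against each module's path
        weight_qtype=jnp.int8,
        act_qtype=jnp.int8,
        op_names=("dot_general",),
    )
]
provider = qwix.QtProvider(rules)
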
25 changes: 13 additions & 12 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -243,34 +243,36 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     return wan_vae, vae_cache

   @classmethod
-  def get_basic_config(cls, dtype):
+  def get_basic_config(cls, dtype, config: HyperParameters):
     rules = [
         qwix.QtRule(
-            module_path=".*",  # Apply to all modules
+            module_path=config.qwix_module_path,  # Regex selecting which modules to quantize; ".*" matches all.
             weight_qtype=dtype,
             act_qtype=dtype,
             op_names=("dot_general",),
         )
     ]
     return rules
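
For context, rules from get_basic_config only take effect once a provider built from them is applied to the model. A minimal sketch, assuming qwix exposes a quantize_model(model, provider, *model_inputs) entry point and assuming the enclosing class is named WanPipeline — neither appears in this diff, so treat both as assumptions:

# Assumed API: qwix.quantize_model and its sample-input argument are not
# shown in this PR; WanPipeline, transformer, and sample_input are
# placeholder names used only for illustration.
rules = WanPipeline.get_basic_config(jnp.int8, config)
provider = qwix.QtProvider(rules)
quantized = qwix.quantize_model(transformer, provider, sample_input)
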

   @classmethod
-  def get_fp8_config(cls, quantization_calibration_method: str):
+  def get_fp8_config(cls, config: HyperParameters):
     """
     FP8 config rules with per-tensor calibration.
     Per the Flax low-level FP8 API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api),
     autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training,
     which is the recommended practice, so bwd_qtype is set to E5M2 explicitly here.
     """
     rules = [
         qwix.QtRule(
-            module_path=".*",  # Apply to all modules
+            module_path=config.qwix_module_path,  # Regex selecting which modules to quantize; ".*" matches all.
             weight_qtype=jnp.float8_e4m3fn,
             act_qtype=jnp.float8_e4m3fn,
-            bwd_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
             bwd_use_original_residuals=True,
             disable_channelwise_axes=True,  # per-tensor calibration
-            weight_calibration_method=quantization_calibration_method,
-            act_calibration_method=quantization_calibration_method,
-            bwd_calibration_method=quantization_calibration_method,
+            weight_calibration_method=config.quantization_calibration_method,
+            act_calibration_method=config.quantization_calibration_method,
+            bwd_calibration_method=config.quantization_calibration_method,
             op_names=("dot_general",),
         )
     ]
     return rules
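
The bwd_qtype switch from E4M3 to E5M2 trades mantissa precision for dynamic range, which suits gradients in the backward pass. A quick sanity check of the two formats' limits in JAX:

import jax.numpy as jnp

# E4M3 keeps more mantissa bits but saturates at 448;
# E5M2 trades precision for range and reaches 57344,
# which is why it is the recommended dtype for gradients.
print(jnp.finfo(jnp.float8_e4m3fn).max)  # 448.0
print(jnp.finfo(jnp.float8_e5m2).max)    # 57344.0
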
@@ -281,14 +283,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
     if not getattr(config, "use_qwix_quantization", False):
       return None

-    quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
     match config.quantization:
       case "int8":
-        return qwix.QtProvider(cls.get_basic_config(jnp.int8))
+        return qwix.QtProvider(cls.get_basic_config(jnp.int8, config))
       case "fp8":
-        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
+        return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config))
       case "fp8_full":
-        return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
+        return qwix.QtProvider(cls.get_fp8_config(config))
     return None
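
Putting it together, a hedged usage sketch of the dispatch above; the WanPipeline class name is assumed from the file name, and the config fields mirror base_wan_14b.yml:

# Assumed names: WanPipeline and a HyperParameters config loaded from
# base_wan_14b.yml; the field assignments are shown only for illustration.
config.use_qwix_quantization = True
config.quantization = "fp8_full"  # one of: int8, fp8, fp8_full
config.quantization_calibration_method = "absmax"
config.qwix_module_path = ".*"

provider = WanPipeline.get_qt_provider(config)
# Returns a qwix.QtProvider built from get_fp8_config(config), or None when
# use_qwix_quantization is False or the quantization value is unrecognized.
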

@classmethod