Skip to content

Commit e12aec8

Browse files
committed
fix commits
1 parent 1e64ebd commit e12aec8

1 file changed

Lines changed: 16 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
246246
def get_basic_config(cls, dtype, config: HyperParameters):
247247
rules = [
248248
qwix.QtRule(
249-
module_path=config.qwix_module_path, # Apply to all modules
249+
module_path=config.qwix_module_path,
250250
weight_qtype=dtype,
251251
act_qtype=dtype,
252-
op_names=("dot_general",),
252+
op_names=("dot_general","einsum", "conv_general_dilated"),
253253
)
254254
]
255255
return rules
@@ -263,7 +263,7 @@ def get_fp8_config(cls, config: HyperParameters):
263263
"""
264264
rules = [
265265
qwix.QtRule(
266-
module_path=config.qwix_module_path, # Apply to all modules
266+
module_path=config.qwix_module_path,
267267
weight_qtype=jnp.float8_e4m3fn,
268268
act_qtype=jnp.float8_e4m3fn,
269269
bwd_qtype=jnp.float8_e5m2,
@@ -272,7 +272,19 @@ def get_fp8_config(cls, config: HyperParameters):
272272
weight_calibration_method=config.quantization_calibration_method,
273273
act_calibration_method=config.quantization_calibration_method,
274274
bwd_calibration_method=config.quantization_calibration_method,
275-
op_names=("dot_general",),
275+
op_names=("dot_general","einsum"),
276+
),
277+
qwix.QtRule(
278+
module_path=config.qwix_module_path,
279+
weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes
280+
act_qtype=jnp.float8_e4m3fn,
281+
bwd_qtype=jnp.float8_e4m3fn,
282+
bwd_use_original_residuals=True,
283+
disable_channelwise_axes=True, # per_tensor calibration
284+
weight_calibration_method=config.quantization_calibration_method,
285+
act_calibration_method=config.quantization_calibration_method,
286+
bwd_calibration_method=config.quantization_calibration_method,
287+
op_names=("conv_general_dilated"),
276288
)
277289
]
278290
return rules

0 commit comments

Comments
 (0)