Commit 1dcd0c9

lint.
1 parent 3e540c5 commit 1dcd0c9

4 files changed: 30 additions & 25 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py
src/maxdiffusion/pyconfig.py
src/maxdiffusion/tests/wan_transformer_test.py
src/maxdiffusion/trainers/wan_trainer.py

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 15 additions & 15 deletions
@@ -233,11 +233,11 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
   def get_basic_config(cls, dtype):
     rules = [
         qwix.QtRule(
-        module_path='.*', # Apply to all modules
-        weight_qtype=dtype,
-        act_qtype=dtype,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=dtype,
+            act_qtype=dtype,
         )
-        ]
+    ]
     return rules
 
   @classmethod
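
For context on the hunk above: `get_basic_config` returns a one-rule list that quantizes weights and activations of every module to the same dtype. A minimal, hedged sketch of the equivalent standalone construction, assuming only the `qwix` package and the `QtRule` keywords visible in this diff (int8 is used as the example dtype):

import jax.numpy as jnp
import qwix

# Sketch: the same one-rule list the method returns. `module_path` is a
# regex matched against module names, so ".*" covers every module.
rules = [
    qwix.QtRule(
        module_path=".*",       # apply to all modules
        weight_qtype=jnp.int8,  # quantize weights to int8
        act_qtype=jnp.int8,     # quantize activations to int8
    )
]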
@@ -249,17 +249,17 @@ def get_fp8_config(cls, quantization_calibration_method: str):
     """
     rules = [
         qwix.QtRule(
-        module_path='.*', # Apply to all modules
-        weight_qtype=jnp.float8_e4m3fn,
-        act_qtype=jnp.float8_e4m3fn,
-        bwd_qtype=jnp.float8_e5m2,
-        bwd_use_original_residuals=True,
-        disable_channelwise_axes=True, # per_tensor calibration
-        weight_calibration_method = quantization_calibration_method,
-        act_calibration_method = quantization_calibration_method,
-        bwd_calibration_method = quantization_calibration_method,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=quantization_calibration_method,
+            act_calibration_method=quantization_calibration_method,
+            bwd_calibration_method=quantization_calibration_method,
         )
-        ]
+    ]
     return rules
 
   @classmethod
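
The rule above encodes a common mixed-format fp8 recipe: `float8_e4m3fn` (more mantissa bits, better precision) for weights and activations on the forward pass, `float8_e5m2` (more exponent bits, better dynamic range for gradients) on the backward pass, and `disable_channelwise_axes=True` so calibration happens per tensor rather than per channel. A hedged sketch of building the same rule standalone; `"absmax"` is borrowed from the test file further down, everything else mirrors this hunk:

import jax.numpy as jnp
import qwix

calibration = "absmax"  # calibration method the tests below use

fp8_rules = [
    qwix.QtRule(
        module_path=".*",                # apply to all modules
        weight_qtype=jnp.float8_e4m3fn,  # forward: e4m3, more precision
        act_qtype=jnp.float8_e4m3fn,
        bwd_qtype=jnp.float8_e5m2,       # backward: e5m2, more range
        bwd_use_original_residuals=True,
        disable_channelwise_axes=True,   # per-tensor calibration
        weight_calibration_method=calibration,
        act_calibration_method=calibration,
        bwd_calibration_method=calibration,
    )
]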
@@ -288,7 +288,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
 
     batch_size = int(config.per_device_batch_size * jax.local_device_count())
     latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
-    model_inputs= (latents, timesteps, prompt_embeds)
+    model_inputs = (latents, timesteps, prompt_embeds)
     with mesh:
       quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
     max_logging.log("Qwix Quantization complete.")
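
Worth noting: `qwix.quantize_model` takes one set of example inputs so it can trace the model, which is why `quantize_transformer` fabricates dummy latents, prompt embeddings, and timesteps first. A hedged restatement of that call sequence; `config`, `pipeline`, `model`, `q_rules`, `mesh`, and `get_dummy_wan_inputs` are assumed to exist as in the surrounding file:

import jax
import qwix

# One global batch of dummy inputs, sized from the per-device batch size.
batch_size = int(config.per_device_batch_size * jax.local_device_count())
latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
model_inputs = (latents, timesteps, prompt_embeds)  # positional order the model expects

with mesh:  # quantize under the device mesh so shardings are respected
  quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)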

src/maxdiffusion/pyconfig.py

Lines changed: 3 additions & 1 deletion
@@ -142,7 +142,9 @@ def wan_init(raw_keys):
     if "quantization" not in raw_keys:
       raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
     elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
-      raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
+      raise ValueError(
+          f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}"
+      )
 
   @staticmethod
   def calculate_global_batch_sizes(per_device_batch_size):
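
The branch above accepts exactly three values for `quantization` when `use_qwix_quantization` is enabled. A hedged, standalone sketch of the same check (the function name and constant are illustrative, not part of pyconfig):

SUPPORTED_QWIX_TYPES = ["int8", "fp8", "fp8_full"]

def validate_quantization(raw_keys: dict) -> None:
  # Mirrors the reformatted branch: a missing key and an unsupported
  # value raise distinct errors.
  if "quantization" not in raw_keys:
    raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
  elif raw_keys["quantization"] not in SUPPORTED_QWIX_TYPES:
    raise ValueError(
        f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}"
    )

validate_quantization({"quantization": "int8"})  # passes silently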

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 8 additions & 7 deletions
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 from qwix import QtProvider
 import os
 import jax
@@ -290,7 +291,7 @@ def test_get_qt_provider(self):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8)
+    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)
 
@@ -300,7 +301,7 @@ def test_get_qt_provider(self):
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
+    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
 
     # Case 4: Quantization enabled, type 'fp8_full'
     config_fp8_full = Mock(spec=HyperParameters)
@@ -309,7 +310,7 @@ def test_get_qt_provider(self):
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)
+    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -318,8 +319,8 @@ def test_get_qt_provider(self):
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
 
   # To test quantize_transformer, we patch its external dependencies
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -348,14 +349,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
 
     mock_model = Mock(spec=WanModel)
 
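Two patterns carry these tests: `Mock(spec=HyperParameters)`, which makes reads of attributes the real class lacks raise `AttributeError`, and `@patch` targeting names where they are used (the `wan_pipeline` module), not where they are defined. A hedged sketch of the int8 case as a standalone snippet; the `HyperParameters` import path is an assumption:

from unittest.mock import Mock

import jax.numpy as jnp

from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion.pyconfig import HyperParameters  # assumed import path

config = Mock(spec=HyperParameters)  # spec= guards against misspelled attribute reads
config.use_qwix_quantization = True
config.quantization = "int8"

provider = WanPipeline.get_qt_provider(config)
assert provider is not None
assert provider._rules[0].weight_qtype == jnp.int8  # private attribute, as in the test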

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 4 additions & 2 deletions
@@ -55,7 +55,9 @@ def generate_sample(config, pipeline, filename_prefix):
   Generates a video to validate training did not corrupt the model
   """
   if not hasattr(pipeline, "vae"):
-    wan_vae, vae_cache = WanPipeline.load_vae(pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config)
+    wan_vae, vae_cache = WanPipeline.load_vae(
+        pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config
+    )
     pipeline.vae = wan_vae
     pipeline.vae_cache = vae_cache
   return generate_wan(config, pipeline, filename_prefix)
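
`generate_sample` re-attaches the VAE lazily because the trainer deletes it after the pre-training sample to free device memory (see the next hunk's context). A hedged sketch of that cache-on-first-use pattern; `load_component` stands in for `WanPipeline.load_vae` and is hypothetical:

def ensure_vae(pipeline, config, load_component):
  # Load the expensive component at most once; later calls reuse the
  # attributes cached on the pipeline object.
  if not hasattr(pipeline, "vae"):
    vae, vae_cache = load_component(config)  # hypothetical loader
    pipeline.vae = vae
    pipeline.vae_cache = vae_cache
  return pipeline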
@@ -147,7 +149,7 @@ def start_training(self):
     pipeline = self.load_checkpoint()
     # Generate a sample before training to compare against generated sample after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
-    
+
     # save some memory.
     del pipeline.vae
     del pipeline.vae_cache
