Skip to content

Commit fcf215f

Browse files
committed
Formatting fixes
1 parent fd257bd commit fcf215f

2 files changed

Lines changed: 9 additions & 10 deletions

File tree

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,12 @@ def wan_init(raw_keys):
137137
else:
138138
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
139139
if "use_qwix_quantization" not in raw_keys:
140-
raise ValueError(f"use_qwix_quantization is not set.")
140+
raise ValueError("use_qwix_quantization is not set.")
141141
elif raw_keys["use_qwix_quantization"]:
142142
if "quantization" not in raw_keys:
143-
raise ValueError(f"Quantization type is not set when use_qwix_quantization is enabled.")
143+
raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
144144
elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
145145
raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
146-
147146

148147
@staticmethod
149148
def calculate_global_batch_sizes(per_device_batch_size):

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import jax.numpy as jnp
2020
import pytest
2121
import unittest
22-
from unittest.mock import Mock, patch, MagicMock
22+
from unittest.mock import Mock, patch
2323
from absl.testing import absltest
2424
from flax import nnx
2525
from jax.sharding import Mesh
@@ -276,7 +276,7 @@ def test_wan_model(self):
276276
hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
277277
)
278278
assert dummy_output.shape == hidden_states_shape
279-
279+
280280
def test_get_qt_provider(self):
281281
"""
282282
Tests the provider logic for all config branches.
@@ -301,7 +301,7 @@ def test_get_qt_provider(self):
301301
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
302302
self.assertIsNotNone(provider_fp8)
303303
self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
304-
304+
305305
# Case 4: Quantization enabled, type 'fp8_full'
306306
config_fp8_full = Mock(spec=HyperParameters)
307307
config_fp8_full.use_qwix_quantization = True
@@ -329,11 +329,11 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
329329
mock_config.use_qwix_quantization = True
330330
mock_config.quantization = "fp8_full"
331331
mock_config.per_device_batch_size = 1
332-
332+
333333
mock_model = Mock(spec=WanModel)
334334
mock_pipeline = Mock()
335335
mock_mesh = Mock()
336-
336+
337337
# Mock the return values of dependencies
338338
mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
339339
mock_quantized_model_obj = Mock(spec=WanModel)
@@ -356,9 +356,9 @@ def test_quantize_transformer_disabled(self, mock_quantize_model):
356356
# Setup Mocks
357357
mock_config = Mock(spec=HyperParameters)
358358
mock_config.use_qwix_quantization = False # Main condition for this test
359-
359+
360360
mock_model = Mock(spec=WanModel)
361-
361+
362362
# Call the method under test
363363
result = WanPipeline.quantize_transformer(mock_config, mock_model, Mock(), Mock())
364364

0 commit comments

Comments (0)