Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
ruff check .
- name: PyTest
run: |
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,5 @@ cache_dreambooth_dataset: False
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,4 @@ quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
use_qwix_quantization: False
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_2_base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -246,4 +246,5 @@ cache_dreambooth_dataset: False
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,6 @@ controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Goo
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,6 @@ controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Goo
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -282,4 +282,5 @@ controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Goo
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_xl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,5 @@ enable_mllog: False
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,5 @@ enable_mllog: False
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,6 @@ cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
jit_initializers: True
enable_single_replica_ckpt_restoring: False
5 changes: 2 additions & 3 deletions src/maxdiffusion/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,12 @@ def wan_init(raw_keys):
else:
raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
if "use_qwix_quantization" not in raw_keys:
raise ValueError(f"use_qwix_quantization is not set.")
raise ValueError("use_qwix_quantization is not set.")
elif raw_keys["use_qwix_quantization"]:
if "quantization" not in raw_keys:
raise ValueError(f"Quantization type is not set when use_qwix_quantization is enabled.")
raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")


@staticmethod
def calculate_global_batch_sizes(per_device_batch_size):
Expand Down
20 changes: 10 additions & 10 deletions src/maxdiffusion/tests/wan_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from qwix import QtProvider
import os
import jax
import jax.numpy as jnp
import pytest
import unittest
from unittest.mock import Mock, patch, MagicMock
from unittest.mock import Mock, patch
from absl.testing import absltest
from flax import nnx
from jax.sharding import Mesh
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_wan_model(self):
hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
)
assert dummy_output.shape == hidden_states_shape

def test_get_qt_provider(self):
"""
Tests the provider logic for all config branches.
Expand All @@ -290,9 +290,9 @@ def test_get_qt_provider(self):
config_int8 = Mock(spec=HyperParameters)
config_int8.use_qwix_quantization = True
config_int8.quantization = "int8"
provider_int8 = WanPipeline.get_qt_provider(config_int8)
provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8)
self.assertIsNotNone(provider_int8)
self.assertEqual(provider_int8.rules[0].kwargs['weight_qtype'], jnp.int8)
self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)

# Case 3: Quantization enabled, type 'fp8'
config_fp8 = Mock(spec=HyperParameters)
Expand All @@ -301,7 +301,7 @@ def test_get_qt_provider(self):
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
self.assertIsNotNone(provider_fp8)
self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)

# Case 4: Quantization enabled, type 'fp8_full'
config_fp8_full = Mock(spec=HyperParameters)
config_fp8_full.use_qwix_quantization = True
Expand Down Expand Up @@ -329,11 +329,11 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
mock_config.use_qwix_quantization = True
mock_config.quantization = "fp8_full"
mock_config.per_device_batch_size = 1

mock_model = Mock(spec=WanModel)
mock_pipeline = Mock()
mock_mesh = Mock()

# Mock the return values of dependencies
mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
mock_quantized_model_obj = Mock(spec=WanModel)
Expand All @@ -356,9 +356,9 @@ def test_quantize_transformer_disabled(self, mock_quantize_model):
# Setup Mocks
mock_config = Mock(spec=HyperParameters)
mock_config.use_qwix_quantization = False # Main condition for this test

mock_model = Mock(spec=WanModel)

# Call the method under test
result = WanPipeline.quantize_transformer(mock_config, mock_model, Mock(), Mock())

Expand Down
8 changes: 4 additions & 4 deletions tests/schedulers/test_scheduler_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def test_full_loop_no_noise(self):
result_mean = jnp.mean(jnp.abs(sample))

if jax_device == "tpu":
assert abs(result_sum - 257.2727) < 1.5e-2
assert abs(result_mean - 0.3349905) < 1e-5
assert abs(result_sum - 257.29) < 1.5e-2
assert abs(result_mean - 0.3349905) < 2e-5
else:
assert abs(result_sum - 255.1113) < 1e-2
assert abs(result_mean - 0.332176) < 1e-3
Expand Down Expand Up @@ -919,7 +919,7 @@ def test_full_loop_with_set_alpha_to_one(self):
result_mean = jnp.mean(jnp.abs(sample))

if jax_device == "tpu":
assert abs(result_sum - 186.83226) < 1e-2
assert abs(result_sum - 186.83226) < 8e-2
assert abs(result_mean - 0.24327) < 1e-3
else:
assert abs(result_sum - 186.9466) < 1e-2
Expand All @@ -932,7 +932,7 @@ def test_full_loop_with_no_set_alpha_to_one(self):
result_mean = jnp.mean(jnp.abs(sample))

if jax_device == "tpu":
assert abs(result_sum - 186.83226) < 1e-2
assert abs(result_sum - 186.83226) < 8e-2
assert abs(result_mean - 0.24327) < 1e-3
else:
assert abs(result_sum - 186.9482) < 1e-2
Expand Down
Loading