
Commit 956828d

Merge branch 'main' into b_436918994

2 parents 1dcd0c9 + c44f0e5

4 files changed: 17 additions & 10 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -177,11 +177,11 @@ To generate images, run the following command:
 ## LTX-Video
 - In the folder src/maxdiffusion/models/ltx_video/utils, run:
 ```bash
-python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../xora_v1.2-13B-balanced-128.json
+python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../ltxv-13B.json
 ```
 - In the repo folder, run:
 ```bash
-python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json"
+python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
 ```
 - Other generation parameters can be set in ltx_video.yml file.
 ## Flux
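
For orientation, here is a minimal sketch of what a torch-to-JAX conversion step like the one above typically performs. This is an illustrative assumption, not the repo's actual convert_torch_weights_to_jax.py; the key handling and the Linear-weight transpose are common Flax conventions, not confirmed details of this script.

```python
# Hypothetical sketch only -- NOT the actual convert_torch_weights_to_jax.py.
import torch
import jax.numpy as jnp


def torch_state_dict_to_jax(ckpt_path):
  """Load a torch checkpoint and convert every tensor to a JAX array."""
  state_dict = torch.load(ckpt_path, map_location="cpu")
  params = {}
  for name, tensor in state_dict.items():
    array = jnp.asarray(tensor.detach().numpy())
    # torch nn.Linear weights are (out_features, in_features); Flax Dense
    # kernels are (in_features, out_features), hence the transpose (assumed).
    if name.endswith(".weight") and array.ndim == 2:
      array = array.T
    params[name] = array
  return params
```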

src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json renamed to src/maxdiffusion/models/ltx_video/ltxv-13B.json

File renamed without changes.

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 4 additions & 6 deletions
@@ -60,12 +60,10 @@

 def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids):
   # Note: reference shape annotated for first pass default inference parameters
-  max_logging.log("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)  # (3, 256, 4096) float32
-  max_logging.log("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)  # (3, 3, 3072) float32
-  max_logging.log("latents.shape: ", latents.shape, latents.dtype)  # (1, 3072, 128) float 32
-  max_logging.log(
-      "encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype
-  )  # (3, 256) int32
+  max_logging.log(f"prompts_embeds.shape: {prompt_embeds.shape}")  # (3, 256, 4096) float32
+  max_logging.log(f"fractional_coords.shape: {fractional_coords.shape}")  # (3, 3, 3072) float32
+  max_logging.log(f"latents.shape: {latents.shape}")  # (1, 3072, 128) float 32
+  max_logging.log(f"encoder_attention_segment_ids.shape: {encoder_attention_segment_ids.shape}")  # (3, 256) int32


 class LTXVideoPipeline:
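
The diff above replaces print-style multi-argument log calls with single f-strings. A plausible reason, assuming max_logging.log forwards exactly one message string to a standard logger (its signature is not shown in this diff, so this is an assumption), is sketched below.

```python
# Standalone sketch; `log` is a hypothetical stand-in for maxdiffusion's
# max_logging.log, assumed here to accept a single message string.
import logging

import numpy as np

logging.basicConfig(level=logging.INFO)


def log(message):
  logging.info(message)  # single-argument interface


latents = np.zeros((1, 3072, 128), dtype=np.float32)
# log("latents.shape: ", latents.shape)  # would raise TypeError: too many args
log(f"latents.shape: {latents.shape}")   # INFO:root:latents.shape: (1, 3072, 128)
```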

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 11 additions & 2 deletions
@@ -278,7 +278,8 @@ def test_wan_model(self):
     )
     assert dummy_output.shape == hidden_states_shape

-  def test_get_qt_provider(self):
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule')
+  def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
@@ -293,9 +294,14 @@ def test_get_qt_provider(self):
     config_int8.quantization = "int8"
     provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
-    self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',
+        weight_qtype=jnp.int8,
+        act_qtype=jnp.int8
+    )

     # Case 3: Quantization enabled, type 'fp8'
+    mock_qt_rule.reset_mock()
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
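
For readers unfamiliar with the mock-assertion pattern adopted above (patch a constructor, assert its exact call arguments, reset between cases), here is a self-contained sketch; patching json.dumps is only a stand-in target for the test's patch of qwix.QtRule.

```python
# Self-contained sketch of patch / assert_called_once_with / reset_mock;
# json.dumps is a stand-in target, where the test patches qwix.QtRule.
import json
from unittest.mock import patch

with patch("json.dumps") as mock_dumps:
  json.dumps({"a": 1}, indent=2)
  # Verify the single recorded call and its exact arguments.
  mock_dumps.assert_called_once_with({"a": 1}, indent=2)
  # Clear call history so the next case can again assert "called once".
  mock_dumps.reset_mock()
  json.dumps({"b": 2})
  mock_dumps.assert_called_once_with({"b": 2})
```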
@@ -304,6 +310,7 @@
     self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)

     # Case 4: Quantization enabled, type 'fp8_full'
+    mock_qt_rule.reset_mock()
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
@@ -334,6 +341,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
     mock_mesh = Mock()
+    mock_mesh.__enter__ = Mock(return_value=None)
+    mock_mesh.__exit__ = Mock(return_value=None)

     # Mock the return values of dependencies
     mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
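
The two added __enter__/__exit__ lines make the mock usable in a `with` statement; a plain Mock does not support the context-manager protocol by default (MagicMock would provide it automatically). A minimal standalone illustration:

```python
# Minimal illustration of turning a plain Mock into a context manager.
from unittest.mock import Mock

mock_mesh = Mock()
mock_mesh.__enter__ = Mock(return_value=None)  # invoked on entering `with`
mock_mesh.__exit__ = Mock(return_value=None)   # falsy return propagates exceptions

with mock_mesh:
  pass  # code under test can now execute inside `with mesh:` blocks

mock_mesh.__enter__.assert_called_once()
```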
