Skip to content

Commit 4253c19

Browse files
committed
fix
1 parent 366deeb commit 4253c19

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,8 +637,11 @@ def _unpack_audio_latents(
637637
if patch_size is not None and patch_size_t is not None:
638638
batch_size = latents.shape[0]
639639
# latents: (Batch, Seq, Dim)
640-
latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
641-
latents = latents.transpose(0, 3, 1, 4, 2, 5).reshape(batch_size, -1, latent_length * patch_size_t, num_mel_bins * patch_size)
640+
# Pack: (B, C, L, F) -> (B, C, L', pt, F', p) -> (B, C, L', pt, F', p) -> (B, L', F', C, pt, p) -> (B, L', F', C*pt*p)
641+
# Unpack: (B, L'*F', C*pt*p) -> (B, L', F', C, pt, p) -> (B, C, L', pt, F', p) -> (B, C, L'*pt, F'*p)
642+
latents = latents.reshape(batch_size, -1, num_mel_bins // patch_size, num_channels * patch_size_t * patch_size)
643+
latents = latents.reshape(batch_size, latent_length // patch_size_t, num_mel_bins // patch_size, num_channels, patch_size_t, patch_size)
644+
latents = latents.transpose(0, 3, 1, 4, 2, 5).reshape(batch_size, num_channels, latent_length, num_mel_bins)
642645
# Wait, reshape order needs to match pack?
643646
# Pack: (B, C, L, F) -> (B, C, L', pt, F', p) -> (B, L', F', C, pt, p) -> (B, L'*F', C*pt*p)
644647
# Unpack: (B, L'*F', C*pt*p) -> (B, L', F', C, pt, p) -> (B, C, L', pt, F', p) -> (B, C, L'*pt, F'*p)

src/maxdiffusion/tests/ltx2_pipeline_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def test_load_transformer(self, mock_load_config, mock_load_weights):
196196
real_model = LTX2VideoTransformer3DModel(**tiny_config, rngs=rngs)
197197

198198
graphdef, state = nnx.split(real_model)
199-
flat_state = state.to_flat_dict()
199+
flat_state = nnx.to_flat_state(state)
200200

201201
# Create mock weights that match real model structure
202202
# keys in flat_state are tuples like ('layer', 'kernel')

0 commit comments

Comments
 (0)