Skip to content

Commit 138f1eb

Browse files
committed
Use the Hugging Face text encoder instead of mock hidden states
1 parent 330b0ce commit 138f1eb

2 files changed

Lines changed: 31 additions & 38 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -402,13 +402,7 @@ def _get_gemma_prompt_embeds(
402402
text_encoder_hidden_states = jnp.array(text_encoder_hidden_states.cpu().numpy())
403403
prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().numpy())
404404
else:
405-
# Mock hidden states
406-
# Should be removed once we have actual text_encoder ready to port
407-
hidden_dim = 1024
408-
num_layers = 2
409-
text_encoder_hidden_states = jnp.zeros(
410-
(batch_size, max_sequence_length, hidden_dim, num_layers), dtype=dtype or jnp.float32
411-
)
405+
raise ValueError("`text_encoder` is required to encode prompts.")
412406

413407
sequence_lengths = prompt_attention_mask.sum(axis=-1)
414408

@@ -605,28 +599,6 @@ def _create_noised_state(
605599
@staticmethod
606600
def _pack_audio_latents(
607601
latents: jax.Array, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
608-
) -> jax.Array:
609-
if patch_size is not None and patch_size_t is not None:
610-
batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
611-
post_patch_latent_length = latent_length // patch_size_t
612-
post_patch_mel_bins = latent_mel_bins // patch_size
613-
latents = latents.reshape(
614-
batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
615-
)
616-
latents = latents.transpose(0, 2, 4, 1, 3, 5).reshape(batch_size, post_patch_latent_length * post_patch_mel_bins, -1)
617-
else:
618-
latents = latents.transpose(0, 2, 1).reshape(batch_size, latents.shape[2], -1)
619-
# Wait, original was transpose(1,2).flatten(2,3) -> (Batch, Channels, Length) -> (Batch, Length, Channels)?
620-
# Diffusers: latents = latents.transpose(1, 2).flatten(2, 3)
621-
# (B, C, L) -> (B, L, C).
622-
# If 4D: (B, C, L, M) -> (B, C, L, P_t, M, P) -> ...
623-
pass
624-
return latents
625-
626-
# Redefining _pack_audio_latents properly for JAX
627-
@staticmethod
628-
def _pack_audio_latents_jax(
629-
latents: jax.Array, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
630602
) -> jax.Array:
631603
if patch_size is not None and patch_size_t is not None:
632604
batch_size, num_channels, latent_length, latent_mel_bins = latents.shape

src/maxdiffusion/tests/ltx2_pipeline_test.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -285,18 +285,39 @@ def test_check_inputs(self):
285285
negative_prompt_embeds=jnp.zeros((1, 5, 64)), # Mismatch length
286286
negative_prompt_attention_mask=jnp.ones((1, 5))
287287
)
288-
rngs = nnx.Rngs(0)
288+
289+
def test_audio_packing_unpacking(self):
290+
# (Batch, Channels, Length, Mel)
291+
batch_size = 1
292+
channels = 128
293+
length = 32
294+
mel = 64
295+
patch_size = 4
296+
patch_size_t = 1 # Audio typically has patch_size_t=1 in LTX logic, let's test that
289297

290-
pipeline = LTX2Pipeline.load_transformer(
291-
devices_array=jnp.array(jax.devices()),
292-
mesh=self.mesh,
293-
rngs=rngs,
294-
config=config,
295-
subfolder="transformer"
298+
latents = jax.random.normal(jax.random.key(0), (batch_size, channels, length, mel))
299+
300+
packed = LTX2Pipeline._pack_audio_latents(latents, patch_size=patch_size, patch_size_t=patch_size_t)
301+
302+
# Verify packed shape
303+
# original logic: (B, T', F', C, p_t, p) -> (B, T' * F', -1)
304+
# T' = 32 // 1 = 32
305+
# F' = 64 // 4 = 16
306+
# shape should be (1, 32 * 16, 128 * 1 * 4) = (1, 512, 512)
307+
expected_seq_len = (length // patch_size_t) * (mel // patch_size)
308+
expected_dim = channels * patch_size * patch_size_t
309+
self.assertEqual(packed.shape, (batch_size, expected_seq_len, expected_dim))
310+
311+
unpacked = LTX2Pipeline._unpack_audio_latents(
312+
packed,
313+
latent_length=length,
314+
num_mel_bins=mel,
315+
patch_size=patch_size,
316+
patch_size_t=patch_size_t
296317
)
297318

298-
mock_create.assert_called_once()
299-
self.assertEqual(pipeline, mock_create.return_value)
319+
self.assertEqual(unpacked.shape, latents.shape)
320+
np.testing.assert_allclose(unpacked, latents, atol=1e-6)
300321

301322
if __name__ == "__main__":
302323
unittest.main()

0 commit comments

Comments
 (0)