Skip to content

Commit 261a492

Browse files
committed
Force the normal LTX-2 vocoder and add latent-shape debug prints
1 parent d1f1e7c commit 261a492

1 file changed

Lines changed: 11 additions & 29 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 11 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -593,37 +593,14 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi
593593

594594
def create_model(rngs: nnx.Rngs, config: HyperParameters):
595595
if getattr(config, "model_name", "") == "ltx2.3":
596-
# Manually construct for LTX-2.3 to support BWE and avoid TypeError
597-
base_vocoder = Vocoder(
598-
upsample_initial_channel=1536,
599-
upsample_rates=(5, 2, 2, 2, 2, 2),
600-
upsample_kernel_sizes=(11, 4, 4, 4, 4, 4),
601-
use_bias_at_final=False,
602-
rngs=rngs,
603-
dtype=jnp.float32,
604-
)
605-
bwe_generator = Vocoder(
606-
upsample_initial_channel=512,
607-
upsample_kernel_sizes=[12, 11, 4, 4, 4],
608-
use_bias_at_final=False,
596+
# Force loading normal vocoder from LTX-2 for isolation
597+
vocoder = LTX2Vocoder.from_config(
598+
"Lightricks/LTX-2",
599+
subfolder="vocoder",
609600
rngs=rngs,
601+
mesh=mesh,
610602
dtype=jnp.float32,
611-
)
612-
mel_stft = MelSTFT(
613-
filter_length=512,
614-
hop_length=80,
615-
win_length=512,
616-
n_mel_channels=64,
617-
rngs=rngs,
618-
)
619-
vocoder = LTX2VocoderWithBWE(
620-
vocoder=base_vocoder,
621-
bwe_generator=bwe_generator,
622-
mel_stft=mel_stft,
623-
input_sampling_rate=16000,
624-
output_sampling_rate=48000,
625-
hop_length=80,
626-
rngs=rngs,
603+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
627604
)
628605
else:
629606
vocoder = LTX2Vocoder.from_config(
@@ -1195,7 +1172,9 @@ def prepare_latents(
11951172
# The packing and unpacking mechanisms expect (B, C, T, H, W).
11961173
latents = latents.transpose(0, 4, 1, 2, 3)
11971174

1175+
print(f"DEBUG: latents shape before pack (passed in): {latents.shape}")
11981176
latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
1177+
print(f"DEBUG: latents shape after pack (passed in): {latents.shape}")
11991178
if latents.ndim != 3:
12001179
raise ValueError("Unexpected latents shape")
12011180
latents = self._create_noised_state(latents, noise_scale, generator)
@@ -1211,7 +1190,9 @@ def prepare_latents(
12111190
generator = jax.random.key(seed)
12121191

12131192
latents = jax.random.normal(generator, shape, dtype=dtype or jnp.float32)
1193+
print(f"DEBUG: latents shape in prepare_latents before pack: {latents.shape}")
12141194
latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
1195+
print(f"DEBUG: latents shape in prepare_latents after pack: {latents.shape}")
12151196
return latents
12161197

12171198
def prepare_audio_latents(
@@ -1327,6 +1308,7 @@ def __call__(
13271308
generator=key_latents,
13281309
latents=latents,
13291310
)
1311+
print(f"DEBUG: latents shape after prepare_latents: {latents.shape}")
13301312

13311313
latent_height = height // self.vae_spatial_compression_ratio
13321314
latent_width = width // self.vae_spatial_compression_ratio

0 commit comments

Comments
 (0)