@@ -593,37 +593,14 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi
593593
594594 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
595595 if getattr (config , "model_name" , "" ) == "ltx2.3" :
596- # Manually construct for LTX-2.3 to support BWE and avoid TypeError
597- base_vocoder = Vocoder (
598- upsample_initial_channel = 1536 ,
599- upsample_rates = (5 , 2 , 2 , 2 , 2 , 2 ),
600- upsample_kernel_sizes = (11 , 4 , 4 , 4 , 4 , 4 ),
601- use_bias_at_final = False ,
602- rngs = rngs ,
603- dtype = jnp .float32 ,
604- )
605- bwe_generator = Vocoder (
606- upsample_initial_channel = 512 ,
607- upsample_kernel_sizes = [12 , 11 , 4 , 4 , 4 ],
608- use_bias_at_final = False ,
596+ # Force loading normal vocoder from LTX-2 for isolation
597+ vocoder = LTX2Vocoder .from_config (
598+ "Lightricks/LTX-2" ,
599+ subfolder = "vocoder" ,
609600 rngs = rngs ,
601+ mesh = mesh ,
610602 dtype = jnp .float32 ,
611- )
612- mel_stft = MelSTFT (
613- filter_length = 512 ,
614- hop_length = 80 ,
615- win_length = 512 ,
616- n_mel_channels = 64 ,
617- rngs = rngs ,
618- )
619- vocoder = LTX2VocoderWithBWE (
620- vocoder = base_vocoder ,
621- bwe_generator = bwe_generator ,
622- mel_stft = mel_stft ,
623- input_sampling_rate = 16000 ,
624- output_sampling_rate = 48000 ,
625- hop_length = 80 ,
626- rngs = rngs ,
603+ weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
627604 )
628605 else :
629606 vocoder = LTX2Vocoder .from_config (
@@ -1195,7 +1172,9 @@ def prepare_latents(
11951172 # The packing and unpacking mechanisms expect (B, C, T, H, W).
11961173 latents = latents .transpose (0 , 4 , 1 , 2 , 3 )
11971174
1175+ print (f"DEBUG: latents shape before pack (passed in): { latents .shape } " )
11981176 latents = self ._pack_latents (latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size )
1177+ print (f"DEBUG: latents shape after pack (passed in): { latents .shape } " )
11991178 if latents .ndim != 3 :
12001179 raise ValueError ("Unexpected latents shape" )
12011180 latents = self ._create_noised_state (latents , noise_scale , generator )
@@ -1211,7 +1190,9 @@ def prepare_latents(
12111190 generator = jax .random .key (seed )
12121191
12131192 latents = jax .random .normal (generator , shape , dtype = dtype or jnp .float32 )
1193+ print (f"DEBUG: latents shape in prepare_latents before pack: { latents .shape } " )
12141194 latents = self ._pack_latents (latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size )
1195+ print (f"DEBUG: latents shape in prepare_latents after pack: { latents .shape } " )
12151196 return latents
12161197
12171198 def prepare_audio_latents (
@@ -1327,6 +1308,7 @@ def __call__(
13271308 generator = key_latents ,
13281309 latents = latents ,
13291310 )
1311+ print (f"DEBUG: latents shape after prepare_latents: { latents .shape } " )
13301312
13311313 latent_height = height // self .vae_spatial_compression_ratio
13321314 latent_width = width // self .vae_spatial_compression_ratio
0 commit comments