@@ -402,13 +402,7 @@ def _get_gemma_prompt_embeds(
402402 text_encoder_hidden_states = jnp .array (text_encoder_hidden_states .cpu ().numpy ())
403403 prompt_attention_mask = jnp .array (prompt_attention_mask .cpu ().numpy ())
404404 else :
405- # Mock hidden states
406- # Should be removed once we have actual text_encoder ready to port
407- hidden_dim = 1024
408- num_layers = 2
409- text_encoder_hidden_states = jnp .zeros (
410- (batch_size , max_sequence_length , hidden_dim , num_layers ), dtype = dtype or jnp .float32
411- )
405+ raise ValueError ("`text_encoder` is required to encode prompts." )
412406
413407 sequence_lengths = prompt_attention_mask .sum (axis = - 1 )
414408
@@ -605,28 +599,6 @@ def _create_noised_state(
605599 @staticmethod
606600 def _pack_audio_latents (
607601 latents : jax .Array , patch_size : Optional [int ] = None , patch_size_t : Optional [int ] = None
608- ) -> jax .Array :
609- if patch_size is not None and patch_size_t is not None :
610- batch_size , num_channels , latent_length , latent_mel_bins = latents .shape
611- post_patch_latent_length = latent_length // patch_size_t
612- post_patch_mel_bins = latent_mel_bins // patch_size
613- latents = latents .reshape (
614- batch_size , - 1 , post_patch_latent_length , patch_size_t , post_patch_mel_bins , patch_size
615- )
616- latents = latents .transpose (0 , 2 , 4 , 1 , 3 , 5 ).reshape (batch_size , post_patch_latent_length * post_patch_mel_bins , - 1 )
617- else :
618- latents = latents .transpose (0 , 2 , 1 ).reshape (batch_size , latents .shape [2 ], - 1 )
619- # Wait, original was transpose(1,2).flatten(2,3) -> (Batch, Channels, Length) -> (Batch, Length, Channels)?
620- # Diffusers: latents = latents.transpose(1, 2).flatten(2, 3)
621- # (B, C, L) -> (B, L, C).
622- # If 4D: (B, C, L, M) -> (B, C, L, P_t, M, P) -> ...
623- pass
624- return latents
625-
626- # Redefining _pack_audio_latents properly for JAX
627- @staticmethod
628- def _pack_audio_latents_jax (
629- latents : jax .Array , patch_size : Optional [int ] = None , patch_size_t : Optional [int ] = None
630602 ) -> jax .Array :
631603 if patch_size is not None and patch_size_t is not None :
632604 batch_size , num_channels , latent_length , latent_mel_bins = latents .shape
0 commit comments