@@ -637,8 +637,11 @@ def _unpack_audio_latents(
637637 if patch_size is not None and patch_size_t is not None :
638638 batch_size = latents .shape [0 ]
639639 # latents: (Batch, Seq, Dim)
640- latents = latents .reshape (batch_size , latent_length , num_mel_bins , - 1 , patch_size_t , patch_size )
641- latents = latents .transpose (0 , 3 , 1 , 4 , 2 , 5 ).reshape (batch_size , - 1 , latent_length * patch_size_t , num_mel_bins * patch_size )
640+ # Pack: (B, C, L, F) -> (B, C, L', pt, F', p) -> (B, L', F', C, pt, p) -> (B, L'*F', C*pt*p)
641+ # Unpack: (B, L'*F', C*pt*p) -> (B, L', F', C, pt, p) -> (B, C, L', pt, F', p) -> (B, C, L'*pt, F'*p)
642+ latents = latents .reshape (batch_size , - 1 , num_mel_bins // patch_size , num_channels * patch_size_t * patch_size )
643+ latents = latents .reshape (batch_size , latent_length // patch_size_t , num_mel_bins // patch_size , num_channels , patch_size_t , patch_size )
644+ latents = latents .transpose (0 , 3 , 1 , 4 , 2 , 5 ).reshape (batch_size , num_channels , latent_length , num_mel_bins )
642645 # Wait, reshape order needs to match pack?
643646 # Pack: (B, C, L, F) -> (B, C, L', pt, F', p) -> (B, L', F', C, pt, p) -> (B, L'*F', C*pt*p)
644647 # Unpack: (B, L'*F', C*pt*p) -> (B, L', F', C, pt, p) -> (B, C, L', pt, F', p) -> (B, C, L'*pt, F'*p)
0 commit comments