@@ -587,6 +587,7 @@ def __call__(
587587 timestep : jax .Array ,
588588 encoder_hidden_states : jax .Array ,
589589 encoder_hidden_states_image : Optional [jax .Array ] = None ,
590+ encoder_attention_mask : Optional [jax .Array ] = None ,
590591 return_dict : bool = True ,
591592 attention_kwargs : Optional [Dict [str , Any ]] = None ,
592593 deterministic : bool = True ,
@@ -606,17 +607,30 @@ def __call__(
606607 hidden_states = self .patch_embedding (hidden_states )
607608 hidden_states = jax .lax .collapse (hidden_states , 1 , - 1 )
608609 with self .conditional_named_scope ("condition_embedder" ):
609- temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image , encoder_attention_mask = self .condition_embedder (
610+ temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image , image_attention_mask = self .condition_embedder (
610611 timestep , encoder_hidden_states , encoder_hidden_states_image
611612 )
612613 timestep_proj = timestep_proj .reshape (timestep_proj .shape [0 ], 6 , - 1 )
613614
615+ # Handle masks for T2V vs I2V
614616 if encoder_hidden_states_image is not None :
617+ # I2V case: concatenate image and text embeddings
615618 encoder_hidden_states = jnp .concatenate ([encoder_hidden_states_image , encoder_hidden_states ], axis = 1 )
616- if encoder_attention_mask is not None :
617- text_mask = jnp .ones ((encoder_hidden_states .shape [0 ], encoder_hidden_states .shape [1 ] - encoder_hidden_states_image .shape [1 ]), dtype = jnp .int32 )
618- encoder_attention_mask = jnp .concatenate ([encoder_attention_mask , text_mask ], axis = 1 )
619+
620+ # Build combined mask: [image_mask | text_mask]
621+ if image_attention_mask is not None :
622+ # We have image mask from embedder
623+ if encoder_attention_mask is not None :
624+ # Use passed text mask (from pipeline)
625+ combined_mask = jnp .concatenate ([image_attention_mask , encoder_attention_mask ], axis = 1 )
626+ else :
627+ # No text mask passed, use all-ones (old behavior for backward compat)
628+ text_len = encoder_hidden_states .shape [1 ] - image_attention_mask .shape [1 ]
629+ text_mask = jnp .ones ((encoder_hidden_states .shape [0 ], text_len ), dtype = jnp .int32 )
630+ combined_mask = jnp .concatenate ([image_attention_mask , text_mask ], axis = 1 )
631+ encoder_attention_mask = combined_mask
619632 encoder_hidden_states = encoder_hidden_states .astype (hidden_states .dtype )
633+ # For T2V: encoder_attention_mask is already the text mask passed from pipeline
620634
621635 if self .scan_layers :
622636
0 commit comments