@@ -250,11 +250,12 @@ def get_1d_rotary_pos_embed(
     return out
 
 class NNXWanImageEmbedding(nnx.Module):
-    def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128):
+    def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128, flash_min_seq_length: int = 4096):
         self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6)
         self.ff = NNXSimpleFeedForward(rngs=rngs, dim=in_features, dim_out=out_features, mult=1, activation_fn="gelu", dtype=dtype, weights_dtype=weights_dtype, precision=precision)
         self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6)
         self.alignment = alignment
+        self.flash_min_seq_length = flash_min_seq_length
         if pos_embed_seq_len is not None:
             self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype))
         else:
@@ -277,10 +278,14 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j
         hidden_states = self.norm2(hidden_states)
         # hidden_states shape: (B, current_seq_len, out_features)
         B, current_seq_len, D_out = hidden_states.shape
+        use_flash_attn = current_seq_len >= self.flash_min_seq_length
 
-        # --- Dynamic Padding to nearest multiple of self.alignment ---
-        num_blocks = (current_seq_len + self.alignment - 1) // self.alignment
-        target_seq_len = num_blocks * self.alignment
+        if use_flash_attn:
+            # --- Dynamic Padding to nearest multiple of self.alignment ---
+            num_blocks = (current_seq_len + self.alignment - 1) // self.alignment
+            target_seq_len = num_blocks * self.alignment
+        else:
+            target_seq_len = current_seq_len
 
         # Create attention mask: 1 for real tokens, 0 for padded tokens
         attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32)
@@ -293,7 +298,8 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j
         # Extend mask with zeros for padded positions
         padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32)
         attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1)
-
+        if not use_flash_attn:
+            attention_mask = None
         return hidden_states, attention_mask
 
 
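For readers skimming the diff: the change gates the alignment padding on sequence length, so short image-embedding sequences skip flash attention (and get no mask) while long ones are padded up to a multiple of `alignment`. Below is a minimal standalone sketch of that logic. It assumes the elided lines between the second and third hunks pad `hidden_states` out to `target_seq_len` with `jnp.pad`; the helper name `pad_for_flash_attention` and the demo shapes are hypothetical, not part of the commit.

    import jax.numpy as jnp

    # Hypothetical standalone sketch mirroring the commit's gating logic;
    # this is not the module itself.
    def pad_for_flash_attention(hidden_states: jnp.ndarray,
                                alignment: int = 128,
                                flash_min_seq_length: int = 4096):
        """Pad the sequence dim to a multiple of `alignment` only when the
        sequence is long enough to route through flash attention."""
        B, current_seq_len, _ = hidden_states.shape
        use_flash_attn = current_seq_len >= flash_min_seq_length

        if use_flash_attn:
            # Ceil-divide to the next multiple of `alignment`,
            # e.g. 4100 -> 33 blocks of 128 -> 4224.
            num_blocks = (current_seq_len + alignment - 1) // alignment
            target_seq_len = num_blocks * alignment
        else:
            target_seq_len = current_seq_len

        # 1 marks real tokens; padded positions get 0.
        attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32)
        padding_size = target_seq_len - current_seq_len
        if padding_size > 0:
            # Assumed from the elided context: zero-pad the sequence dim.
            hidden_states = jnp.pad(hidden_states,
                                    ((0, 0), (0, padding_size), (0, 0)))
            attention_mask = jnp.concatenate(
                [attention_mask,
                 jnp.zeros((B, padding_size), dtype=jnp.int32)], axis=1)

        # Short sequences bypass flash attention, so no mask is needed.
        if not use_flash_attn:
            attention_mask = None
        return hidden_states, attention_mask

    # Example (hypothetical shapes): a 257-token embedding stays unpadded
    # with mask=None, while a 4100-token one is padded to 4224.
    short, short_mask = pad_for_flash_attention(jnp.zeros((1, 257, 1280)))
    long_, long_mask = pad_for_flash_attention(jnp.zeros((1, 4100, 1280)))
    assert short.shape[1] == 257 and short_mask is None
    assert long_.shape[1] == 4224 and long_mask.shape == (1, 4224)

Returning `attention_mask = None` for the short path lets the caller fall back to plain (unmasked, unpadded) attention, which avoids wasted compute on padding when the sequence is too short for the flash kernel to pay off.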