@@ -251,23 +251,44 @@ def get_1d_rotary_pos_embed(
251251
252252
class NNXWanImageEmbedding(nnx.Module):
    """Project image encoder embeddings for the Wan model (Flax NNX).

    Pipeline: optional learned positional embedding (added over the
    overlapping prefix of the sequence) -> FP32 LayerNorm -> one-hidden-layer
    GELU feed-forward (``in_features`` -> ``out_features``) -> FP32 LayerNorm.
    Finally the sequence axis is zero-padded up to the next multiple of
    ``alignment`` so downstream attention kernels see block-aligned lengths.
    """

    def __init__(
        self,
        rngs: nnx.Rngs,
        in_features: int,
        out_features: int,
        dtype: jnp.dtype,
        weights_dtype: jnp.dtype,
        precision: jax.lax.Precision,
        pos_embed_seq_len=None,
        alignment: int = 128,
    ):
        self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6)
        self.ff = NNXSimpleFeedForward(
            rngs=rngs,
            dim=in_features,
            dim_out=out_features,
            mult=1,
            activation_fn="gelu",
            dtype=dtype,
            weights_dtype=weights_dtype,
            precision=precision,
        )
        self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6)
        # Output sequence length is rounded up to a multiple of this value.
        self.alignment = alignment
        if pos_embed_seq_len is not None:
            self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype))
        else:
            # nnx.data(None): attribute reads back as None, so the __call__
            # `is not None` check skips the positional-embedding branch.
            self.pos_embed = nnx.data(None)

    def __call__(self, encoder_hidden_states_image: jax.Array) -> jax.Array:
        """Embed image features.

        Args:
            encoder_hidden_states_image: (batch, seq_len, in_features) array.

        Returns:
            (batch, padded_seq_len, out_features) array, where
            padded_seq_len = ceil(seq_len / alignment) * alignment.
        """
        hidden_states = encoder_hidden_states_image
        seq_len = hidden_states.shape[1]

        if self.pos_embed is not None:
            pe_len = self.pos_embed.value.shape[1]
            add_len = min(seq_len, pe_len)
            # Add the learned positional embedding over the shared prefix only;
            # a longer input keeps its tail un-embedded instead of erroring.
            hidden_states = hidden_states.at[:, :add_len, :].add(self.pos_embed.value[:, :add_len, :])
            if seq_len > pe_len:
                # NOTE(review): trace-time print; consider warnings/logging.
                print(f"[WARN] Input seq_len {seq_len} > pos_embed len {pe_len}")

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states)

        # --- Pad the sequence axis to the next multiple of self.alignment ---
        seq_len = hidden_states.shape[1]
        target_seq_len = -(-seq_len // self.alignment) * self.alignment  # ceil-div
        if seq_len < target_seq_len:
            # jnp.pad with zero mode replaces the manual zeros+concatenate;
            # same values, same dtype, one call.
            hidden_states = jnp.pad(hidden_states, ((0, 0), (0, target_seq_len - seq_len), (0, 0)))
            # NOTE(review): leftover debug print; remove or route to a logger.
            print(f"[DEBUG EMB] Padded image embeds from {seq_len} to {target_seq_len}. New shape: {hidden_states.shape}")

        return hidden_states
272293
273294
0 commit comments