Commit 9412399

padding added for divisibility by 128
1 parent 5631d54 commit 9412399

1 file changed: 26 additions & 5 deletions

src/maxdiffusion/models/embeddings_flax.py

@@ -251,23 +251,44 @@ def get_1d_rotary_pos_embed(
 
 
 class NNXWanImageEmbedding(nnx.Module):
-  def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None):
+  def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128):
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6)
     self.ff = NNXSimpleFeedForward(rngs=rngs, dim=in_features, dim_out=out_features, mult=1, activation_fn="gelu", dtype=dtype, weights_dtype=weights_dtype, precision=precision)
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6)
+    self.alignment = alignment
     if pos_embed_seq_len is not None:
       self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype))
     else:
       self.pos_embed = nnx.data(None)
 
   def __call__(self, encoder_hidden_states_image: jax.Array) -> jax.Array:
+    hidden_states = encoder_hidden_states_image
+    B, current_seq_len, D_in = hidden_states.shape
+
     if self.pos_embed is not None:
-      batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
-      encoder_hidden_states_image = encoder_hidden_states_image.reshape((-1, 2 * seq_len, embed_dim))
-      encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
-      hidden_states = self.norm1(encoder_hidden_states_image)
+      pe_len = self.pos_embed.value.shape[1]
+      add_len = min(current_seq_len, pe_len)
+      # Apply pos_embed to the original sequence length
+      hidden_states = hidden_states.at[:, :add_len, :].add(self.pos_embed.value[:, :add_len, :])
+      if current_seq_len > pe_len:
+        print(f"[WARN] Input seq_len {current_seq_len} > pos_embed len {pe_len}")
+
+      hidden_states = self.norm1(hidden_states)
     hidden_states = self.ff(hidden_states)
     hidden_states = self.norm2(hidden_states)
+    # hidden_states shape: (B, current_seq_len, out_features)
+    B, current_seq_len, D_out = hidden_states.shape
+
+    # --- Dynamic Padding to nearest multiple of self.alignment ---
+    num_blocks = (current_seq_len + self.alignment - 1) // self.alignment
+    target_seq_len = num_blocks * self.alignment
+
+    if current_seq_len < target_seq_len:
+      padding_size = target_seq_len - current_seq_len
+      padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype)
+      hidden_states = jnp.concatenate([hidden_states, padding], axis=1)
+      print(f"[DEBUG EMB] Padded image embeds from {current_seq_len} to {target_seq_len}. New shape: {hidden_states.shape}")
+
     return hidden_states
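
The padding step rounds the sequence length up to the next multiple of alignment (default 128, per the commit message) with the usual ceil-division trick: ceil(seq_len / alignment) computed as (seq_len + alignment - 1) // alignment. A minimal standalone sketch of the same computation, assuming only jax.numpy; the helper name pad_seq_to_multiple is illustrative and not part of this commit:

import jax.numpy as jnp

def pad_seq_to_multiple(x: jnp.ndarray, alignment: int = 128) -> jnp.ndarray:
  """Zero-pad axis 1 of a (B, S, D) array up to the next multiple of alignment."""
  _, seq_len, _ = x.shape
  # Ceil division: smallest count of alignment-sized blocks that covers seq_len.
  num_blocks = (seq_len + alignment - 1) // alignment
  pad_amount = num_blocks * alignment - seq_len  # 0 when already aligned
  # jnp.pad with a zero pad width is a no-op, so no Python branch is required.
  return jnp.pad(x, ((0, 0), (0, pad_amount), (0, 0)))

# 257 tokens round up to 384 (3 * 128); an already-aligned 256 stays 256.
assert pad_seq_to_multiple(jnp.ones((2, 257, 64))).shape == (2, 384, 64)
assert pad_seq_to_multiple(jnp.ones((2, 256, 64))).shape == (2, 256, 64)

Because x.shape yields static Python ints under jax.jit, both this branch-free version and the commit's explicit `if current_seq_len < target_seq_len` check are resolved at trace time.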
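The positional-embedding path also changed: instead of reshaping to 2 * seq_len and adding the whole table, the new code adds only the first min(seq_len, pe_len) rows with a functional .at[...].add(...) update, which tolerates a mismatch between the input length and the stored table length. A self-contained sketch of that pattern, assuming a (1, P, D) table broadcast over the batch (names are illustrative):

import jax.numpy as jnp

def add_pos_embed(hidden_states: jnp.ndarray, pos_embed: jnp.ndarray) -> jnp.ndarray:
  """Add a (1, P, D) positional table to the first min(S, P) positions of (B, S, D)."""
  add_len = min(hidden_states.shape[1], pos_embed.shape[1])  # static ints under jit
  # JAX arrays are immutable; .at[...].add(...) returns an updated copy,
  # broadcasting the (1, add_len, D) slice across the batch dimension.
  return hidden_states.at[:, :add_len, :].add(pos_embed[:, :add_len, :])

h = jnp.zeros((2, 300, 16))
pe = jnp.ones((1, 257, 16))
out = add_pos_embed(h, pe)
assert float(out[:, :257].sum()) == 2 * 257 * 16  # table applied to the first 257 rows
assert float(out[:, 257:].sum()) == 0.0           # trailing positions untouched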