
Commit 4475978

trying dot attn fix
1 parent 13048e7 commit 4475978

3 files changed

Lines changed: 17 additions & 5 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 0 deletions
@@ -1127,6 +1127,9 @@ def __call__(
     image_seq_len_actual = 257
     padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment  # 257 -> 384
 
+    if encoder_attention_mask is None:
+      padded_img_len = image_seq_len_actual
+
     encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
     encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
 
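With this change, the split point between the image and text parts of the concatenated encoder states only assumes the alignment-padded image length when a mask was actually built; if no encoder_attention_mask is passed, the image embedding is taken at its true, unpadded length. A minimal standalone sketch of that arithmetic (the helper name is hypothetical, not maxdiffusion code); with alignment = 128 the 257-token image sequence rounds up to 384:

def image_split_point(image_seq_len_actual: int, alignment: int, has_mask: bool) -> int:
  # Round up to the next multiple of `alignment` only when a padding mask exists;
  # otherwise the image block is assumed to be its true, unpadded length.
  if not has_mask:
    return image_seq_len_actual
  return ((image_seq_len_actual + alignment - 1) // alignment) * alignment

assert image_split_point(257, 128, has_mask=True) == 384   # padded for flash attention
assert image_split_point(257, 128, has_mask=False) == 257  # unpadded fallback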

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 11 additions & 5 deletions
@@ -250,11 +250,12 @@ def get_1d_rotary_pos_embed(
   return out
 
 class NNXWanImageEmbedding(nnx.Module):
-  def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128):
+  def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128, flash_min_seq_length: int = 4096):
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6)
     self.ff = NNXSimpleFeedForward(rngs=rngs, dim=in_features, dim_out=out_features, mult=1, activation_fn="gelu", dtype=dtype, weights_dtype=weights_dtype, precision=precision)
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6)
     self.alignment = alignment
+    self.flash_min_seq_length = flash_min_seq_length
     if pos_embed_seq_len is not None:
       self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype))
     else:

@@ -277,10 +278,14 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j
     hidden_states = self.norm2(hidden_states)
     # hidden_states shape: (B, current_seq_len, out_features)
     B, current_seq_len, D_out = hidden_states.shape
+    use_flash_attn = current_seq_len >= self.flash_min_seq_length
 
-    # --- Dynamic Padding to nearest multiple of self.alignment ---
-    num_blocks = (current_seq_len + self.alignment - 1) // self.alignment
-    target_seq_len = num_blocks * self.alignment
+    if use_flash_attn:
+      # --- Dynamic Padding to nearest multiple of self.alignment ---
+      num_blocks = (current_seq_len + self.alignment - 1) // self.alignment
+      target_seq_len = num_blocks * self.alignment
+    else:
+      target_seq_len = current_seq_len
 
     # Create attention mask: 1 for real tokens, 0 for padded tokens
     attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32)

@@ -293,7 +298,8 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j
       # Extend mask with zeros for padded positions
       padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32)
       attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1)
-
+    if not use_flash_attn:
+      attention_mask = None
     return hidden_states, attention_mask
 
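Taken together, the embedder now only pads the image-embedding sequence (and builds a 0/1 mask) when the sequence is long enough to be routed through flash attention; below flash_min_seq_length it is left at its true length and the mask comes back as None, which is exactly the case the attention_flax change above falls back to. A minimal sketch of that decision, assuming the hidden states themselves are padded to target_seq_len in lines not shown in this hunk; it is an illustration, not the NNXWanImageEmbedding method:

import jax.numpy as jnp

def pad_for_attention(hidden_states, alignment=128, flash_min_seq_length=4096):
  B, current_seq_len, _ = hidden_states.shape
  if current_seq_len < flash_min_seq_length:
    # Short sequences skip flash attention: no padding, no mask.
    return hidden_states, None
  # Round the sequence length up to the next multiple of `alignment`.
  target_seq_len = ((current_seq_len + alignment - 1) // alignment) * alignment
  padding_size = target_seq_len - current_seq_len
  attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32)
  if padding_size > 0:
    hidden_states = jnp.pad(hidden_states, ((0, 0), (0, padding_size), (0, 0)))
    attention_mask = jnp.concatenate(
        [attention_mask, jnp.zeros((B, padding_size), dtype=jnp.int32)], axis=1)
  return hidden_states, attention_mask

# Arbitrary shapes, for illustration only: 257 < 4096, so no padding and mask is None.
h, mask = pad_for_attention(jnp.zeros((1, 257, 1280)))

With the default threshold of 4096, the 257-token image embedding is no longer padded to 384, so downstream code has to tolerate a None mask; lowering flash_min_seq_length would restore the old padded behaviour.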

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      flash_min_seq_length: int = 4096
   ):
     self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0)
     self.time_embedder = NNXTimestepEmbedding(

@@ -148,6 +149,7 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
+        flash_min_seq_length=flash_min_seq_length
     )
 
   def __call__(

@@ -502,6 +504,7 @@ def __init__(
         text_embed_dim=text_dim,
         image_embed_dim=image_dim,
         pos_embed_seq_len=pos_embed_seq_len,
+        flash_min_seq_length=flash_min_seq_length
     )
 
     # 3. Transformer blocks
