
Commit 9651aeb

Shard k, v, q across the context axis via a fix in attention_ltx2.py
1 parent 50160c6 commit 9651aeb

2 files changed

Lines changed: 27 additions & 28 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 25 additions & 8 deletions

@@ -267,21 +267,38 @@ def _tpu_flash_attention(
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
   num_context_shards = mesh.shape["context"]
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  def _pad_3d(tensor, num_shards):
+    org_len = tensor.shape[1]
+    rem = org_len % num_shards
+    if rem == 0:
+      return tensor, org_len
+    pad_width = [(0, 0)] * tensor.ndim
+    pad_width[1] = (0, num_shards - rem)
+    return jnp.pad(tensor, pad_width), org_len
+
+  query, orig_q_seq_len = _pad_3d(query, num_context_shards)
+  key, _ = _pad_3d(key, num_context_shards)
+  value, _ = _pad_3d(value, num_context_shards)
+
+  # Define 3D sharding specs (Batch, Seq, None)
+  q_axis_names_3d = nn.logical_to_mesh_axes((axis_names_q[0], axis_names_q[2], None))
+  kv_axis_names_3d = nn.logical_to_mesh_axes((axis_names_kv[0], axis_names_kv[2], None))
+
+  # Output spec is still 4D [Batch, Heads, Seq, HeadDim]
+  q_axis_names_4d = nn.logical_to_mesh_axes(axis_names_q)
 
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
-      out_specs=q_axis_names,
+      in_specs=(q_axis_names_3d, kv_axis_names_3d, kv_axis_names_3d),
+      out_specs=q_axis_names_4d,
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
+    # Reshape to 4D inside shard_map to avoid All-Gather during transpose
+    query = _unflatten_heads(query, heads)
+    key = _unflatten_heads(key, heads)
+    value = _unflatten_heads(value, heads)
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
         block_sizes.block_q,
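
Note on the pattern above: query, key and value stay flattened as [batch, seq, heads * head_dim] and sequence-sharded while they cross the shard_map boundary; each shard then unflattens heads locally, so the reshape/transpose no longer forces an all-gather over the context axis, and _pad_3d first pads the sequence so it divides evenly across the context shards. Below is a minimal, self-contained sketch of that pattern, not the repo's code; helper names such as pad_seq_to_multiple and unflatten_heads are illustrative.

import functools

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def pad_seq_to_multiple(x, num_shards):
  """Pad axis 1 (sequence) so it divides evenly by num_shards; keep the original length."""
  seq_len = x.shape[1]
  rem = seq_len % num_shards
  if rem == 0:
    return x, seq_len
  pad = [(0, 0)] * x.ndim
  pad[1] = (0, num_shards - rem)
  return jnp.pad(x, pad), seq_len


heads, head_dim = 4, 8
mesh = Mesh(jax.devices(), axis_names=("context",))

# Flattened 3D activations: [batch, seq, heads * head_dim].
q = jnp.ones((2, 37, heads * head_dim))
q, orig_seq_len = pad_seq_to_multiple(q, mesh.shape["context"])


@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=P(None, "context", None),         # 3D input: shard the sequence axis
    out_specs=P(None, None, "context", None),  # 4D output: [batch, heads, seq, head_dim]
    check_rep=False,
)
def unflatten_heads(x):
  # Runs per shard, so the reshape/transpose stays local to each device.
  b, s, _ = x.shape
  return x.reshape(b, s, heads, head_dim).transpose(0, 2, 1, 3)


out = unflatten_heads(q)
print(out.shape)  # (2, heads, padded_seq, head_dim); slice back to orig_seq_len afterwards

The original sequence length is returned alongside the padded tensor for the same reason orig_q_seq_len is kept above: the attention output has to be sliced back to the unpadded length after the kernel runs.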

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 2 additions & 20 deletions

@@ -16,7 +16,6 @@
 
 from typing import Optional, Tuple
 from flax import nnx
-from flax import linen as nn
 import jax
 import jax.numpy as jnp
 from ... import common_types
@@ -486,25 +485,8 @@ def __call__(
       key = apply_rotary_emb(key, rotary_emb)
 
     with jax.named_scope("Attention and Output Project"):
-      # Reshape to 4D [B, H, S, D] before passing to avoid All-Gather during transpose
-      b, s_q, _ = query.shape
-      _, s_kv, _ = key.shape
-      h = self.heads
-      d = self.dim_head
-
-      query = query.reshape(b, s_q, h, d).transpose(0, 2, 1, 3)
-      key = key.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)
-      value = value.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)
-
-      # Apply explicit sharding constraints on the 4D tensors
-      q_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV))
-      kv_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV))
-
-      query = jax.lax.with_sharding_constraint(query, q_axis_names)
-      key = jax.lax.with_sharding_constraint(key, kv_axis_names)
-      value = jax.lax.with_sharding_constraint(value, kv_axis_names)
-
-      # 4. Attention (passing 4D tensors now)
+      # 4. Attention
+      # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
       attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
 
       # 7. Output Projection
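
With the constraint block above removed, the LTX2 layer now hands the attention op flattened [B, S, InnerDim] tensors and leaves the head unflattening (and its sharding) to the shard_map wrapper in attention_flax.py. As a quick sanity check of the shapes involved, here is a hedged sketch of the flatten/unflatten round-trip, with illustrative helper names rather than the repo's API.

import jax.numpy as jnp


def flatten_heads(x_4d):
  # [B, H, S, D] -> [B, S, H * D]
  b, h, s, d = x_4d.shape
  return x_4d.transpose(0, 2, 1, 3).reshape(b, s, h * d)


def unflatten_heads(x_3d, heads):
  # [B, S, H * D] -> [B, H, S, D]; exact inverse of flatten_heads
  b, s, hd = x_3d.shape
  return x_3d.reshape(b, s, heads, hd // heads).transpose(0, 2, 1, 3)


x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)  # [B, H, S, D]
assert jnp.array_equal(unflatten_heads(flatten_heads(x), heads=3), x)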
