 from typing import Optional, Tuple
 from flax import nnx
+from flax import linen as nn
 import jax
 import jax.numpy as jnp
 from ... import common_types
@@ -485,8 +486,24 @@ def __call__(
     key = apply_rotary_emb(key, rotary_emb)

     with jax.named_scope("Attention and Output Project"):
-      # 4. Attention
-      # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
+      # Reshape to 4D [B, H, S, D] before passing to avoid All-Gather during transpose
+      b, s, _ = query.shape
+      h = self.heads
+      d = self.dim_head
+
+      query = query.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+      key = key.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+      value = value.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+
+      # Apply explicit sharding constraints on the 4D tensors
+      q_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV))
+      kv_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV))
+
+      query = jax.lax.with_sharding_constraint(query, q_axis_names)
+      key = jax.lax.with_sharding_constraint(key, kv_axis_names)
+      value = jax.lax.with_sharding_constraint(value, kv_axis_names)
+
+      # 4. Attention (passing 4D tensors now)
       attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

       # 7. Output Projection
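For context on what the added lines do, below is a minimal, self-contained sketch of the reshape-and-constrain pattern: per-head activations are reshaped from [B, S, H*D] to [B, H, S, D], the logical axis names are resolved to a PartitionSpec with `nn.logical_to_mesh_axes`, and the layout is pinned with `jax.lax.with_sharding_constraint` so the compiler is not free to pick a layout that requires an all-gather around the transpose. The mesh shape, the logical axis names (`activation_batch`, etc.), and the logical-to-mesh rules in the sketch are illustrative assumptions; they stand in for the repository's `common_types` constants and its config-driven sharding rules.

```python
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.sharding import Mesh

# Toy mesh built from whatever devices are available; axis names are assumptions.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "tensor"))

# Hypothetical logical-to-mesh rules standing in for the repo's configured rules.
rules = (
    ("activation_batch", "data"),
    ("activation_heads", "tensor"),
    ("activation_length", None),
    ("activation_kv", None),
)

def reshape_and_constrain(x, heads, dim_head, logical_names):
  b, s, _ = x.shape
  # [B, S, H*D] -> [B, H, S, D], the same layout change as in the diff above.
  x = x.reshape(b, s, heads, dim_head).transpose(0, 2, 1, 3)
  # Resolve logical names to a PartitionSpec and pin the sharding explicitly.
  spec = nn.logical_to_mesh_axes(logical_names)
  return jax.lax.with_sharding_constraint(x, spec)

with mesh, nn.logical_axis_rules(rules):
  q = jnp.ones((2, 128, 8 * 64))  # [B, S, H*D]
  fn = jax.jit(lambda t: reshape_and_constrain(
      t, heads=8, dim_head=64,
      logical_names=("activation_batch", "activation_heads",
                     "activation_length", "activation_kv")))
  print(fn(q).shape)  # (2, 8, 128, 64)
```

Placing the constraint immediately after the transpose, and before `apply_attention`, is what keeps the query/key/value tensors in their sharded per-head layout when they enter the attention kernel.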