Commit 4543686

remove heads sharding constraint after rope for seq parallelism.
1 parent ae9c952 commit 4543686

3 files changed

Lines changed: 22 additions & 23 deletions

File tree

src/maxdiffusion/common_types.py
src/maxdiffusion/models/attention_flax.py
src/maxdiffusion/models/wan/transformers/transformer_wan.py

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@

 BATCH = "activation_batch"
 LENGTH = "activation_length"
+KV_LENGTH = "activation_kv_length"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 D_KV = "activation_kv"
src/maxdiffusion/models/attention_flax.py

Lines changed: 21 additions & 16 deletions
@@ -38,6 +38,7 @@
 AxisNames = common_types.AxisNames
 BATCH = common_types.BATCH
 LENGTH = common_types.LENGTH
+KV_LENGTH = common_types.KV_LENGTH
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED
@@ -156,7 +157,8 @@ def _tpu_flash_attention(
     value: jax.Array,
     heads: int,
     mesh: Mesh,
-    flash_axis_names: AxisNames,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
@@ -181,8 +183,8 @@ def _tpu_flash_attention(
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
   key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  axis_names = nn.logical_to_mesh_axes(flash_axis_names)
-  kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV))
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
@@ -200,7 +202,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=shard_head_size,  # number of shards over the heads axis
-        q_seq_shards=num_fsdp_shards,
+        q_seq_shards=num_fsdp_shards,  # number of shards over the query sequence axis
         block_sizes=block_sizes,
     )
     return splash_kernel
@@ -213,12 +215,12 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
       shard_map.shard_map,
       mesh=mesh,
       in_specs=(
-          axis_names,
+          q_axis_names,
           kv_axis_names,
           kv_axis_names,
           segment_axis_names_splash_kernel,
       ),
-      out_specs=axis_names,
+      out_specs=q_axis_names,
       check_rep=False
   )
   def wrap_flash_attention(query, key, value, splash_kernel):
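This hunk is the core of the sequence-parallel layout: the query now gets its own partition spec in the shard_map, while the key/value tensors keep a spec whose sequence axis is replicated, so each device works on a slice of the query sequence against the full KV sequence. The following is a minimal, self-contained sketch of that pattern; the toy shapes, the (1, 1, N) mesh layout, and the plain softmax attention stand in for the real splash kernel and are assumptions, not code from this commit:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Toy mesh: put every available device on the "fsdp" axis.
mesh = Mesh(np.array(jax.devices()).reshape(1, 1, -1), axis_names=("data", "tensor", "fsdp"))

q_spec = P("data", "tensor", "fsdp", None)   # query: sequence axis sharded over "fsdp"
kv_spec = P("data", "tensor", None, None)    # key/value: sequence axis replicated

@partial(shard_map, mesh=mesh, in_specs=(q_spec, kv_spec, kv_spec), out_specs=q_spec, check_rep=False)
def toy_attention(q, k, v):
  # Each device sees its local chunk of queries but the full KV sequence,
  # so attention over the whole KV length stays exact.
  logits = jnp.einsum("bhqd,bhkd->bhqk", q, k)
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)

b, h, d = 2, 4, 16
s = 8 * jax.device_count()  # keep the sequence divisible by the "fsdp" axis size
q = jnp.ones((b, h, s, d))
k = jnp.ones((b, h, s, d))
v = jnp.ones((b, h, s, d))
print(toy_attention(q, k, v).shape)  # (b, h, s, d)
```

Keeping KV replicated along its sequence axis is what makes the per-device computation exact; sharding the KV sequence as well would require a ring or all-gather step that the splash kernel is not being asked to do here.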
@@ -359,7 +361,8 @@ def _apply_attention(
     scale: float,
     dtype: jnp.dtype,
     mesh: Mesh,
-    flash_axis_names: AxisNames,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
 ):
@@ -382,7 +385,7 @@ def _apply_attention(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
+    return _tpu_flash_attention(query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype)
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
@@ -505,7 +508,8 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       float32_qk_product: bool = True,
-      flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+      axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+      axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
       flash_min_seq_length: int = 4096,
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
@@ -523,7 +527,8 @@ def __init__(
     self.use_memory_efficient_attention = use_memory_efficient_attention
     self.split_head_dim = split_head_dim
     self.float32_qk_product = float32_qk_product
-    self.flash_axis_names = flash_axis_names
+    self.axis_names_q = axis_names_q
+    self.axis_names_kv = axis_names_kv
     self.flash_min_seq_length = flash_min_seq_length
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
@@ -544,7 +549,8 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         scale=self.scale,
         dtype=self.dtype,
         mesh=self.mesh,
-        flash_axis_names=self.flash_axis_names,
+        axis_names_q=self.axis_names_q,
+        axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
     )
@@ -559,7 +565,8 @@ class AttentionOp(nn.Module):
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
-  flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV)
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
@@ -600,7 +607,8 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         scale=self.scale,
         dtype=self.dtype,
         mesh=self.mesh,
-        flash_axis_names=self.flash_axis_names,
+        axis_names_q=self.axis_names_q,
+        axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
     )
@@ -764,9 +772,6 @@ def __call__(
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-    query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None))
-    key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None))
-    value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", "tensor", None, None))

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
     attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None))
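The three deleted lines are the heads-sharding constraints named in the commit message: immediately after RoPE they pinned query, key, and value to a layout with the heads axis on "tensor" and the sequence axis unsharded, which conflicts with the sequence-sharded layout the flash-attention shard_map now requests for the query and would force an extra resharding. For reference, here is a minimal sketch of the kind of constraint that was removed; the single-device toy mesh and tensor shapes are assumptions so the snippet runs anywhere:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-device toy mesh, just so the sketch is runnable for illustration.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "tensor"))

@jax.jit
def constrain_after_rope(query_proj):
  # Equivalent of the removed line: batch on "data", heads on "tensor",
  # sequence axis left unsharded.  With sequence parallelism the downstream
  # attention wants the sequence axis sharded instead, so keeping this
  # constraint would insert a resharding between RoPE and attention.
  return jax.lax.with_sharding_constraint(
      query_proj, NamedSharding(mesh, P("data", "tensor", None, None))
  )

x = jnp.ones((2, 4, 8, 16))  # (batch, heads, seq, head_dim)
print(constrain_after_rope(x).sharding)
```

With the constraints gone, the layout is instead driven by the in_specs of the flash-attention shard_map, so the compiler can keep the sequence-sharded query layout through this block.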

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 0 additions & 7 deletions
@@ -43,13 +43,6 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
     freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False)
     freqs.append(freq)
   freqs = jnp.concatenate(freqs, axis=1)
-  # sizes = jnp.array([
-  #     attention_head_dim // 2 - 2 * (attention_head_dim // 6),
-  #     attention_head_dim // 6,
-  #     attention_head_dim // 6,
-  # ])
-  # cumulative_sizes = jnp.cumsum(jnp.array(sizes))
-  # split_indices = cumulative_sizes[:-1]
   t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)
   hw_size = attention_head_dim // 6
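The deleted comments computed the same frequency split sizes via a cumulative sum; the two surviving lines compute them directly. A quick worked check, assuming attention_head_dim = 128 purely for illustration:

```python
attention_head_dim = 128  # illustrative value, not taken from the model config

t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)  # 64 - 2 * 21 = 22
hw_size = attention_head_dim // 6                                  # 21

# The temporal block plus the two spatial blocks cover half the head dim:
assert t_size + 2 * hw_size == attention_head_dim // 2  # 22 + 42 == 64
```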
