Commit 3839870

Fix transformer sharding, flash block sizing, and tests

1 parent 384d211, commit 3839870
3 files changed: 119 additions & 38 deletions

src/maxdiffusion/models/attention_flax.py
Lines changed: 55 additions & 32 deletions
@@ -190,6 +190,49 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   return tensor, kv_size, seq_len


+def _flash_sequence_length(tensor: Array) -> int:
+  if tensor.ndim == 3:
+    return tensor.shape[1]
+  if tensor.ndim == 4:
+    return tensor.shape[2]
+  raise ValueError(f"Flash attention expects rank-3 or rank-4 inputs, got rank {tensor.ndim}.")
+
+
+def _select_flash_block_sizes(
+    query: Array,
+    key: Array,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype,
+    attention_kernel: str,
+) -> BlockSizes:
+  query_seq_len = _flash_sequence_length(query)
+  key_seq_len = _flash_sequence_length(key)
+
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  if key_seq_len != query_seq_len:
+    kv_max_block_size = ((key_seq_len + 127) // 128) * 128
+  else:
+    kv_max_block_size = q_max_block_size
+
+  # keep configured block sizes for self-attention, but let
+  # cross-attention derive safe KV-aware sizes when q_len != kv_len.
+  if flash_block_sizes and key_seq_len == query_seq_len:
+    return flash_block_sizes
+
+  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
+  return splash_attention_kernel.BlockSizes(
+      block_q=block_size_q,
+      block_kv_compute=min(kv_max_block_size, key_seq_len),
+      block_kv=min(kv_max_block_size, key_seq_len),
+      block_q_dkv=block_size_q,
+      block_kv_dkv=min(kv_max_block_size, key_seq_len),
+      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
+  )
+
+
 def convert_to_tokamax_splash_config(
     block_sizes: BlockSizes,
     q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
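
Note on the rounding above: for cross-attention, the KV block cap is the key sequence length rounded up to the next multiple of 128, presumably to match the TPU's 128-wide lane tiling. A minimal sketch of the arithmetic (illustration only, not code from this commit):

  def round_up_to_128(n: int) -> int:
    # ((n + 127) // 128) * 128 rounds n up to the nearest multiple of 128.
    return ((n + 127) // 128) * 128

  assert round_up_to_128(512) == 512  # already aligned
  assert round_up_to_128(77) == 128   # e.g. a short text-encoder sequence
  assert round_up_to_128(300) == 384
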
@@ -244,28 +287,7 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # This is the case for cross-attn.
-  if key.shape[1] != query.shape[1]:
-    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  # ensure that for cross attention we override the block sizes.
-  if flash_block_sizes and key.shape[1] == query.shape[1]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-    )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
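
The inlined version this replaces read the sequence length from axis 1 in the cross-attention check but from axis 2 inside the min() caps; routing both through _flash_sequence_length makes the derivation rank-agnostic. Illustrative shapes for the convention the helper assumes:

  q3 = jnp.ones((1, 4096, 128))     # rank 3 (batch, seq, dim):        seq on axis 1
  q4 = jnp.ones((1, 4, 4096, 128))  # rank 4 (batch, heads, seq, dim): seq on axis 2
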
@@ -979,7 +1001,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )

@@ -993,7 +1015,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )

@@ -1007,7 +1029,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )

@@ -1021,7 +1043,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("heads",),
+            ("embed",),
         ),
     )
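
The four bias_init changes follow from the bias shape: a Linear bias has the layer's output features, so its one partitioning axis must name the output dimension. The first three layers (by position, presumably the query/key/value projections) emit the head-sharded inner dimension and move to ("heads",); the fourth (presumably the output projection back to the model dimension) moves to ("embed",). Schematically, under that assumption:

  # q/k/v: embed -> heads  => bias shape (inner_dim,) => axes ("heads",)
  # out:   heads -> embed  => bias shape (embed_dim,) => axes ("embed",)
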

@@ -1263,8 +1285,7 @@ def __call__(

     with jax.named_scope("proj_attn"):
       hidden_states = self.proj_attn(attn_output)
-      if self.drop_out.rate > 0:
-        hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states

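Here, and in the matching transformer_wan.py change below, the rate > 0 guard is dropped and dropout is applied unconditionally. This is safe because Flax's Dropout already returns its input unchanged when the rate is 0 or deterministic is set; a paraphrased sketch of that short-circuit (not the library source):

  import jax
  import jax.numpy as jnp

  def dropout(x, rate, deterministic, rng):
    # No-op path: zero rate or deterministic mode passes x through unchanged.
    if rate == 0.0 or deterministic:
      return x
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)
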
@@ -1333,11 +1354,13 @@ def setup(self):
         precision=self.precision,
     )

+    proj_attn_kernel_axes = ("heads", "embed")
+
     self.proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="i_proj",
@@ -1346,9 +1369,9 @@ def setup(self):

     self.encoder_proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="e_proj",

src/maxdiffusion/models/wan/transformers/transformer_wan.py
Lines changed: 4 additions & 5 deletions
@@ -193,11 +193,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
                 "embed",
+                "mlp",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
@@ -249,8 +249,8 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
+                "mlp",
                 "embed",
-                "mlp",
             ),
         ),
     )
@@ -262,8 +262,7 @@ def conditional_named_scope(self, name: str):
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
     hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
     hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    if self.drop_out.rate > 0:
-      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     with jax.named_scope("proj_out"):
       return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
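
The WAN feed-forward annotations follow the same (in_features, out_features) kernel layout: the up-projection (embed -> mlp, 5120 -> 13824 per the shape comments in __call__ above) becomes ("embed", "mlp") with its bias on the ("mlp",) output axis, while proj_out (mlp -> embed) flips to ("mlp", "embed"). Schematically, under the same assumption:

  # up:   kernel (embed_dim, mlp_dim) -> ("embed", "mlp"), bias ("mlp",)
  # down: kernel (mlp_dim, embed_dim) -> ("mlp", "embed"), bias ("embed",)
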

src/maxdiffusion/tests/attention_test.py
Lines changed: 60 additions & 1 deletion
@@ -20,8 +20,9 @@
 import jax
 from jax.sharding import Mesh
 import jax.numpy as jnp
-from ..models.attention_flax import FlaxAttention
+from ..common_types import BlockSizes
 from .. import max_utils
+from ..models.attention_flax import FlaxAttention, _select_flash_block_sizes
 from .. import pyconfig

 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -35,6 +36,8 @@ def setUp(self):

   def test_splash_attention(self):
     """Test numerics of splash attention are equivalent to dot_product"""
+    if jax.devices()[0].platform != "tpu":
+      self.skipTest("TPU splash attention test requires a TPU backend.")

     pyconfig.initialize(
         [
@@ -92,6 +95,62 @@ def test_splash_attention(self):

     assert diff_norm < 1.0

+  def test_select_flash_block_sizes_keeps_self_attention_config(self):
+    flash_block_sizes = BlockSizes(
+        block_q=2048,
+        block_kv=1024,
+        block_kv_compute=1024,
+        block_q_dkv=2048,
+        block_kv_dkv=1024,
+        block_kv_dkv_compute=1024,
+        block_q_dq=2048,
+        block_kv_dq=1024,
+    )
+
+    query = jnp.ones((1, 4096, 128), dtype=jnp.bfloat16)
+    key = jnp.ones((1, 4096, 128), dtype=jnp.bfloat16)
+
+    selected = _select_flash_block_sizes(query, key, flash_block_sizes, jnp.bfloat16, "flash")
+
+    self.assertEqual(selected, flash_block_sizes)
+
+  def test_select_flash_block_sizes_overrides_cross_attention_kv_blocks(self):
+    flash_block_sizes = BlockSizes(
+        block_q=2048,
+        block_kv=2048,
+        block_kv_compute=1024,
+        block_q_dkv=2048,
+        block_kv_dkv=2048,
+        block_kv_dkv_compute=1024,
+        block_q_dq=2048,
+        block_kv_dq=1024,
+    )
+
+    query = jnp.ones((1, 4096, 128), dtype=jnp.bfloat16)
+    key = jnp.ones((1, 512, 128), dtype=jnp.bfloat16)
+
+    selected = _select_flash_block_sizes(query, key, flash_block_sizes, jnp.bfloat16, "flash")
+
+    self.assertEqual(selected.block_q, flash_block_sizes.block_q)
+    self.assertEqual(selected.block_q_dkv, flash_block_sizes.block_q)
+    self.assertEqual(selected.block_q_dq, flash_block_sizes.block_q)
+    self.assertEqual(selected.block_kv, 512)
+    self.assertEqual(selected.block_kv_compute, 512)
+    self.assertEqual(selected.block_kv_dkv, 512)
+    self.assertEqual(selected.block_kv_dkv_compute, 512)
+    self.assertEqual(selected.block_kv_dq, 512)
+
+  def test_select_flash_block_sizes_uses_sequence_axis_for_rank_4_inputs(self):
+    query = jnp.ones((1, 4, 4096, 128), dtype=jnp.bfloat16)
+    key = jnp.ones((1, 4, 512, 128), dtype=jnp.bfloat16)
+
+    selected = _select_flash_block_sizes(query, key, None, jnp.bfloat16, "flash")
+
+    self.assertEqual(selected.block_q, 1024)
+    self.assertEqual(selected.block_kv, 512)
+    self.assertEqual(selected.block_kv_compute, 512)
+    self.assertEqual(selected.block_kv_dkv_compute, 512)
+

 if __name__ == "__main__":
   absltest.main()
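
Assuming the repository's usual pytest workflow, the new block-size tests run on any backend (only test_splash_attention needs a TPU, and it now skips itself elsewhere), e.g.:

  python -m pytest src/maxdiffusion/tests/attention_test.py -k select_flash_block_sizes
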
