
Commit 4d1775f: set q_seq_shards=1

Parent: 7c84ec2
File changed: src/maxdiffusion/models/attention_flax.py (2 additions, 6 deletions)
@@ -201,17 +201,13 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
   splash_kernel = splash_attention_kernel.make_splash_mha(
       mask=multi_head_mask,
       head_shards=shard_head_size,  # the size of the axis sharded over heads
-      q_seq_shards=num_fsdp_shards,  # the size of the axis sharded over seq_len
+      q_seq_shards=1,  # the size of the axis sharded over seq_len
       block_sizes=block_sizes,
   )
   return splash_kernel

 mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-mask &= splash_attention_mask.LocalMask(
-    shape=(query.shape[2], key.shape[2]),
-    window_size=(query.shape[2], key.shape[2]),
-    offset=0
-)
+
 multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
 splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
 segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
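
The net effect of the change: q_seq_shards is pinned to 1 rather than following the FSDP shard count, and the LocalMask intersection is removed, so the kernel runs with a plain FullMask over the full sequence. Below is a minimal, self-contained sketch of the resulting kernel construction; the shapes, head_shards=1, and the default BlockSizes are illustrative assumptions, not values taken from this commit.

    # Sketch of the post-commit splash kernel construction.
    # Shapes, head_shards=1, and default BlockSizes are assumptions.
    from jax.experimental.pallas.ops.tpu.splash_attention import (
        splash_attention_kernel,
        splash_attention_mask,
    )

    q_seq_len, kv_seq_len, num_heads = 1024, 1024, 8  # assumed shapes

    # Full attention mask; the LocalMask intersection removed by this
    # commit no longer restricts the attention window.
    mask = splash_attention_mask.FullMask(_shape=(q_seq_len, kv_seq_len))
    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)

    splash_kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask,
        head_shards=1,   # assumed: no sharding over heads in this sketch
        q_seq_shards=1,  # pinned to 1, as in this commit
        block_sizes=splash_attention_kernel.BlockSizes.get_default(),
    )

The returned kernel is then applied per device (typically under shard_map, as the manual_sharding_spec call above suggests) as splash_kernel(query, key, value) on arrays shaped [num_heads, seq_len, head_dim]; with q_seq_shards=1, the kernel no longer assumes the query sequence is split across the FSDP mesh axis.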
