@@ -281,45 +281,47 @@ def wrap_flash_attention(query, key, value):
   block_q = max(*block_q_sizes)
   block_kv = max(*block_kv_sizes)

+  # FIX: Always pad data. The kernel requires seq_len % block_size == 0.
+  query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+  key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+  value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
   if attention_kernel == "tokamax_flash":
-    # OPTIMIZATION: Skip padding and segment_ids for the optimized kernel
-    kv_size = key.shape[-1]
-    query_seq_len = query.shape[2]
-    segment_ids = None
+    # OPTIMIZATION: padding is still required, but we skip building
+    # segment_ids (extra overhead) and rely on the kernel's internal
+    # masking for the padded regions.
+    segment_ids = None
+
+    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
+    splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+        mask=mask,
+        q_seq_shards=1,
+        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+        save_residuals=False,
+    )

-    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-    splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
-        mask=mask,
-        q_seq_shards=1,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=False,  # Ring attention not typically used in this path
-    )
   else:
-    # STANDARD PATH: Explicit padding (Slower)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=True if attention_kernel == "ring" else False,
-        residual_checkpoint_name=residual_checkpoint_name
-    )
+    # STANDARD PATH: explicit padding + segment IDs
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,
+        q_seq_shards=1,
+        block_sizes=block_sizes,
+        save_residuals=True if attention_kernel == "ring" else False,
+        residual_checkpoint_name=residual_checkpoint_name
+    )
   vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

   if not mask_padding_tokens:
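
For reference, a minimal sketch of the padding step the fix depends on. This is a hypothetical stand-in, not the repository's _pad_data_for_flash helper (which also takes a heads argument and returns a third value); it only illustrates padding the sequence axis up to the next multiple of the block size while remembering the true length, which is what makes seq_len % block_size == 0 hold.

import jax.numpy as jnp

def pad_to_block_multiple(x, block_size, seq_axis=2):
  """Hypothetical sketch: pad the sequence axis of x up to the next
  multiple of block_size; return the padded array and the true length."""
  seq_len = x.shape[seq_axis]
  padded_len = -(-seq_len // block_size) * block_size  # ceil to a multiple
  pad_widths = [(0, 0)] * x.ndim
  pad_widths[seq_axis] = (0, padded_len - seq_len)
  return jnp.pad(x, pad_widths), seq_len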
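
The jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) call maps the kernel over the batch axis of query/key/value while sharing one segment_ids structure across the whole batch. A toy example of the same vmap pattern (not the splash kernel itself), where the shared segment vector masks out padded key positions:

import jax
import jax.numpy as jnp

def toy_attention(q, k, v, seg):
  # q, k, v: (heads, seq, dim); seg: (seq,) shared across the batch.
  scores = jnp.einsum("hqd,hkd->hqk", q, k)
  scores = jnp.where(seg[None, None, :] > 0, scores, -1e9)  # mask padded keys
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)

batched = jax.vmap(toy_attention, in_axes=(0, 0, 0, None))
q = k = v = jnp.ones((4, 2, 8, 16))          # (batch, heads, seq, dim)
seg = (jnp.arange(8) < 6).astype(jnp.int32)  # last two positions are padding
out = batched(q, k, v, seg)                  # shape (4, 2, 8, 16)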