@@ -279,39 +279,43 @@ def wrap_flash_attention(query, key, value):
     block_kv_sizes += (block_sizes.block_kv_dq,)
 
   block_q = max(*block_q_sizes)
-  query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
   block_kv = max(*block_kv_sizes)
-  key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-  value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-  q_padded_len = query.shape[2]
-  q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-  q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
 
-  kv_padded_len = key.shape[2]
-  kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-  kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-  segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-  # make_splash_mha is wrapped in shard_map and seq and heads are already
-  # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
   if attention_kernel == "tokamax_flash":
+    # OPTIMIZATION: skip padding and segment_ids construction for the optimized kernel.
+    kv_size = key.shape[-1]
+    query_seq_len = query.shape[2]
+    segment_ids = None
+
     mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
         mask=mask,
-        q_seq_shards=1,  # the size of the axis sharded over seq_len
+        q_seq_shards=1,
         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True if attention_kernel == "ring" else False,
+        save_residuals=False,  # ring attention is not typically used in this path
     )
   else:
+    # STANDARD PATH: explicit padding (slower).
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
-        head_shards=1,  # the size of the axis sharded over heads
-        q_seq_shards=1,  # the size of the axis sharded over seq_len
+        head_shards=1,
+        q_seq_shards=1,
         block_sizes=block_sizes,
        save_residuals=True if attention_kernel == "ring" else False,
         residual_checkpoint_name=residual_checkpoint_name
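
For context, the standard path pads query, key, and value along the sequence axis up to a multiple of the kernel block size, then uses segment ids to mark which positions are real tokens so the splash kernel ignores the padding. The sketch below illustrates that idea only; _pad_seq_to_block is a hypothetical stand-in for MaxText's _pad_data_for_flash, whose return values are assumed from how it is called in the diff ([batch, heads, seq, head_dim] layout, returning the padded tensor, the head dimension, and the original sequence length).

import jax
import jax.numpy as jnp

def _pad_seq_to_block(x, block):
  # Sketch only: pad the sequence axis (axis 2) of a [batch, heads, seq, head_dim]
  # tensor up to the next multiple of `block`.
  seq_len = x.shape[2]
  padded_len = -(-seq_len // block) * block  # ceiling division
  x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0)))
  return x, x.shape[-1], seq_len

# Segment ids distinguish real tokens (1) from padding (0) so the attention
# kernel's softmax skips the padded positions.
q = jnp.zeros((1, 8, 1000, 128))
q_padded, head_dim, q_len = _pad_seq_to_block(q, 512)  # padded from 1000 to 1024
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded.shape[2],), 0)
q_segment_ids = (q_indices < q_len).astype(jnp.int32)  # 1 for real tokens, 0 for pad

The tokamax_flash branch skips this work entirely and passes segment_ids=None, avoiding both the pad/copy of the inputs and the extra mask construction.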