 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
+from .padded_flash_attn import make_dense_padded_attention

 from . import quantizations

@@ -236,20 +237,23 @@ def wrap_flash_attention(query, key, value):
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-    # make_splash_mha is wrapped around shardmap and seq and head is already
-    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
         save_residuals=True if attention_kernel == "ring" else False,
     )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0)

     if attention_kernel == "flash":
+      # attention_output = vmapped_splash(query, key, value, segment_ids)
       attention_output = vmapped_splash(query, key, value, segment_ids)
+    elif attention_kernel == "dense_padded":
+      padded_kv_len = key.shape[1] - key_seq_len
+      dense_padded_attention_kernel = make_dense_padded_attention(block_sizes=block_sizes, kv_padding=padded_kv_len)
+      vmapped_splash = jax.vmap(dense_padded_attention_kernel, in_axes=(0, 0, 0), out_axes=0)
+      attention_output, _ = vmapped_splash(query, key, value)
     else:
       if num_fsdp_shards > 1:
         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
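
In the "flash" path above, padding is handled through segment ids: positions past the true sequence length get segment id 0, valid positions get 1, and the splash kernel only attends where the query and KV segment ids match. The new "dense_padded" path instead hands the amount of KV padding (key.shape[1] - key_seq_len) to make_dense_padded_attention. A minimal sketch (not part of this commit) of how the segment-id construction excludes padded positions:

    import jax
    import jax.numpy as jnp

    def padding_segment_ids(true_len, padded_len):
      # Valid positions get segment id 1, padded positions get 0; attention is
      # only computed where the query and KV segment ids match.
      positions = jax.lax.broadcasted_iota(jnp.int32, (padded_len,), 0)
      return (positions < true_len).astype(jnp.int32)

    q_segment_ids = padding_segment_ids(true_len=100, padded_len=128)
    kv_segment_ids = padding_segment_ids(true_len=77, padded_len=128)
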
@@ -458,6 +462,19 @@ def _apply_attention(
         dtype,
         attention_kernel,
     )
+  elif attention_kernel == "dense_padded":
+    return _tpu_flash_attention(
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        attention_kernel,
+    )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
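
As in the existing flash and ring branches, the new dense_padded branch passes key * scale into _tpu_flash_attention rather than scaling logits inside the kernel; folding the 1/sqrt(head_dim) factor into the key gives the same attention logits. A minimal sketch (not part of this commit) of that equivalence:

    import jax
    import jax.numpy as jnp

    head_dim = 8
    scale = head_dim ** -0.5  # the usual 1/sqrt(head_dim) softmax scale
    q = jax.random.normal(jax.random.PRNGKey(0), (4, head_dim))
    k = jax.random.normal(jax.random.PRNGKey(1), (6, head_dim))

    # Scaling the key before the matmul yields the same logits as scaling after.
    assert jnp.allclose(q @ (k * scale).T, (q @ k.T) * scale, atol=1e-6)
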
@@ -877,10 +894,10 @@ def __call__(
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
-
-    query_proj = self.query(hidden_states)
-    key_proj = self.key(encoder_hidden_states)
-    value_proj = self.value(encoder_hidden_states)
+    with jax.named_scope("attention-projection"):
+      query_proj = self.query(hidden_states)
+      key_proj = self.key(encoder_hidden_states)
+      value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
       query_proj = self.norm_q(query_proj)
@@ -895,7 +912,8 @@ def __call__(
     query_proj = checkpoint_name(query_proj, "query_proj")
     key_proj = checkpoint_name(key_proj, "key_proj")
     value_proj = checkpoint_name(value_proj, "value_proj")
-    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    with jax.named_scope("attention-compute"):
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
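
jax.named_scope is a context manager that attaches a name to every operation traced inside it, so the projection and attention phases appear as labelled regions in profiler traces and HLO metadata. A minimal sketch (not part of this commit) of the same pattern:

    import jax
    import jax.numpy as jnp

    def attend(x, w_q, w_k):
      with jax.named_scope("attention-projection"):
        q = x @ w_q
        k = x @ w_k
      with jax.named_scope("attention-compute"):
        return jax.nn.softmax(q @ k.T, axis=-1)

    out = jax.jit(attend)(jnp.ones((4, 8)), jnp.ones((8, 8)), jnp.ones((8, 8)))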