from jax.experimental import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from jax.experimental.pallas.ops.tpu import flash_attention
+from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention
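+# The module import exposes flash_attention.BlockSizes; the kernel function is aliased as
+# jax_flash_attention so it does not shadow the module name.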
from einops import rearrange
from .. import common_types, max_logging

@@ -103,7 +105,7 @@ def _unflatten_heads(tensor, heads):
  tensor = jnp.transpose(tensor, (0, 2, 1, 3))
  return tensor

-def _reshape_data_for_flash(tensor, heads, flash_block_size):
+def _reshape_data_for_flash(tensor, heads, flash_block_size, pad=True):
  """
  Reshapes tensors for pallas flash attention, adding padding to both seq_len and head_dim.
  """
@@ -127,12 +129,28 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size):
  # pad to the nearest multiple of flash_block_size
  seq_len_pad = (mul + 1) * flash_block_size - seq_len

-  if kv_size < 128 or rem != 0:
+  if (kv_size < 128 or rem != 0) and pad:
    npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
    tensor = jnp.pad(tensor, npad)

  return tensor, kv_size, seq_len

+def default_block_sizes(query: jax.Array, key: jax.Array, dtype: jnp.dtype):
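+  # Defaults: 1024-wide blocks for bfloat16, 512 for float32, each clipped to the actual
+  # query/key sequence length (and block_b to the batch size).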
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  return flash_attention.BlockSizes(
+      block_q=min(max_block_size, query.shape[-2]),
+      block_k_major=min(max_block_size, key.shape[-2]),
+      block_k=min(max_block_size, key.shape[-2]),
+      block_b=min(1, query.shape[0]),
+      block_q_major_dkv=min(max_block_size, query.shape[-2]),
+      block_k_major_dkv=min(max_block_size, key.shape[-2]),
+      block_q_dkv=min(max_block_size, query.shape[-2]),
+      block_k_dkv=min(max_block_size, key.shape[-2]),
+      block_q_dq=min(max_block_size, query.shape[-2]),
+      block_k_dq=min(512, key.shape[-2]),
+      block_k_major_dq=min(max_block_size, key.shape[-2]),
+  )
+
def _tpu_flash_attention(
    query: jax.Array,
    key: jax.Array,
@@ -144,6 +162,59 @@ def _tpu_flash_attention(
    dtype: jnp.dtype = jnp.float32) -> jax.Array:
  """TPU Flash Attention"""

+  if flash_block_sizes is None:
+    block_sizes = default_block_sizes(query, key, dtype)
+  else:
+    block_sizes = flash_block_sizes
+
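+  # Pad q/k/v so seq_len is a multiple of the block size; small head dims are padded up as well
+  # (see _reshape_data_for_flash).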
+  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, pad=True)
+  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_k, pad=True)
+  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_k, pad=True)
+
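+  # shard_map runs the Pallas kernel on each device shard along the logical flash axes;
+  # check_rep=False disables shard_map's output replication check.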
+  axis_names = nn.logical_to_mesh_axes(flash_axis_names)
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          axis_names,
+          axis_names,
+          axis_names,
+      ),
+      out_specs=axis_names,
+      check_rep=False,
+  )
+  def wrap_flash_attention(query, key, value):
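+    # Only block_sizes is overridden; other kernel options keep their defaults
+    # (the caller pre-scales the key by the attention scale).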
+    output = jax_flash_attention(
+        query,
+        key,
+        value,
+        block_sizes=block_sizes,
+    )
+    return output
+
+  devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
+  # This warning can show up during model eval, for example when calculating model flops,
+  # and that is expected.
+  if not (query.shape[0] / devices_in_data_fsdp).is_integer():
+    max_logging.log(
+        "Warning, batch dimension should be shardable among the devices in data and fsdp"
+        f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
+    )
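+  # Run the sharded kernel, then strip the seq_len/head_dim padding and fold heads back into the head dimension.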
+  x = wrap_flash_attention(query, key, value)
+  x = x[:, :, :query_seq_len, :kv_size]
+  x = _reshape_heads_to_head_dim(x)
+  return x
+
+
+def _tpu_splash_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    flash_axis_names: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32) -> jax.Array:
+  """TPU Splash Attention"""
+
  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
  if flash_block_sizes:
    block_sizes = flash_block_sizes
@@ -182,6 +253,7 @@ def wrap_flash_attention(query, key, value):
    splash_kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
    )
+    # return splash_kernel(query, key, value)
    return jax.vmap(splash_kernel)(query, key, value)

  devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
@@ -344,6 +416,8 @@ def _apply_attention(
    return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention)
  elif attention_kernel == "flash":
    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
+  elif attention_kernel == "splash":
+    return _tpu_splash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
  elif attention_kernel == "cudnn_flash_te":
    return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
  else: