@@ -226,25 +226,26 @@ def wrap_flash_attention(query, key, value):
     key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
     value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
 
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,  # number of shards along the heads axis
-        q_seq_shards=1,  # number of shards along the seq_len axis
-        block_sizes=block_sizes,
-        save_residuals=attention_kernel == "ring",
-    )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0)
+    if attention_kernel == "flash":
+        mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+        multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+        q_padded_len = query.shape[2]
+        q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+        q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+        kv_padded_len = key.shape[2]
+        kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+        kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+        segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+        splash_kernel = splash_attention_kernel.make_splash_mha(
+            mask=multi_head_mask,
+            head_shards=1,  # number of shards along the heads axis
+            q_seq_shards=1,  # number of shards along the seq_len axis
+            block_sizes=block_sizes,
+            save_residuals=attention_kernel == "ring",
+        )
+        vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0)
 
     if attention_kernel == "flash":
         # attention_output = vmapped_splash(query, key, value, segment_ids)
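
Note on the masking logic added in this hunk: the segment ids built with `jax.lax.broadcasted_iota` simply mark real tokens with 1 and the padded tail with 0, and `in_axes=(0, 0, 0, None)` maps the splash kernel over the batch axis while broadcasting the same segment ids to every batch element. A minimal standalone sketch of the id construction (the `padding_segment_ids` helper and the toy lengths are illustrative, not part of this change):

```python
import jax
import jax.numpy as jnp

def padding_segment_ids(padded_len: int, true_len: int) -> jnp.ndarray:
    # 1 for real tokens, 0 for the padded tail, mirroring the hunk above.
    indices = jax.lax.broadcasted_iota(jnp.int32, (padded_len,), 0)
    return (indices < true_len).astype(jnp.int32)

print(padding_segment_ids(8, 5))  # [1 1 1 1 1 0 0 0]
```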