-# Copyright 2025 Google LLC
+# Copyright 2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

 import jax
 import jax.numpy as jnp
-from MaxText.kernels import splash_attention_kernel
+from maxtext.kernels.attention import splash_attention_kernel

 SegmentIds = splash_attention_kernel.SegmentIds

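`SegmentIds` pairs a per-token segment id for the queries with one for the keys/values; the kernel only lets tokens attend within a matching segment. A minimal sketch of constructing one for packed sequences (toy sizes, not from this commit):

```python
# Two sequences of lengths 3 and 5 packed into one row of length 8.
ids = jnp.asarray([[0, 0, 0, 1, 1, 1, 1, 1]])  # (batch, seq_len)
segment_ids = SegmentIds(q=ids, kv=ids)  # self-attention: queries and kv share ids
```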
@@ -77,38 +77,34 @@ def flash_attention_block_masked(
   v_head_dim_size = v.shape[-1]
   data_type = q.dtype
   q_groups = num_q_heads // num_kv_heads
-  q = q.reshape((
-      batch_size,
-      num_kv_heads,
-      q_groups,
-      q_seq_len,
-      qk_head_dim_size,
-  ))
+  q = q.reshape(
+      (
+          batch_size,
+          num_kv_heads,
+          q_groups,
+          q_seq_len,
+          qk_head_dim_size,
+      )
+  )

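The reshape groups the query heads by the key/value head they share, so each kv block can be contracted against all of its query heads at once. A toy shape check (hypothetical sizes, reusing the `jnp` import above):

```python
batch_size, num_q_heads, num_kv_heads = 2, 8, 2
q_seq_len, qk_head_dim_size = 128, 64
q = jnp.zeros((batch_size, num_q_heads, q_seq_len, qk_head_dim_size))
q_groups = num_q_heads // num_kv_heads  # 4 query heads per kv head
q = q.reshape((batch_size, num_kv_heads, q_groups, q_seq_len, qk_head_dim_size))
assert q.shape == (2, 2, 4, 128, 64)
```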
   # Calculate the number of key/value and query blocks.
   num_kv_blocks = kv_seq_len // block_kv
   num_q_blocks = q_seq_len // block_q

   # Before applying the segment mask, broadcast the mask along the batch
   # dimension, since the same logic applies to all batches.
-  mask_full = jnp.broadcast_to(
-      mask[None, :, :], (batch_size, q_seq_len, kv_seq_len)
-  )
+  mask_full = jnp.broadcast_to(mask[None, :, :], (batch_size, q_seq_len, kv_seq_len))

   if segment_ids is not None:
     segment_ids_q = segment_ids.q[:, :, None]
     segment_ids_kv = segment_ids.kv[:, None, :]
     mask_full = jnp.logical_and(mask_full, segment_ids_q == segment_ids_kv)
-  mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])(
-      mask_full, block_q, block_kv
-  )
+  mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])(mask_full, block_q, block_kv)

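`block_q` and `block_kv` are marked static because `mask_blocker` uses them to build reshape shapes, which must be concrete at trace time; `jit` then specializes per distinct block size. A small sketch (toy sizes, not from this commit):

```python
blocker = jax.jit(mask_blocker, static_argnums=[1, 2])
mask = jnp.ones((1, 8, 8), dtype=jnp.bool_)
counts = blocker(mask, 4, 4)  # traced and compiled for block sizes (4, 4)
counts = blocker(mask, 2, 2)  # new block sizes trigger a re-trace
```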
   # Initialize `l` (logsumexp) and `m` (max_logits) for the online softmax.
   # `l` is initialized to 0 since no blocks have been processed yet and the sum
   # is 0.
-  l = jnp.zeros(
-      (batch_size, num_kv_heads, q_groups, q_seq_len), dtype=data_type
-  )
+  l = jnp.zeros((batch_size, num_kv_heads, q_groups, q_seq_len), dtype=data_type)
   # `m` is initialized to the mask_value so that the first block's maximum logit
   # correctly becomes the running maximum.
   m = jnp.full(
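For reference, the bookkeeping that `l` and `m` implement is the standard streaming ("online") softmax: when a new block raises the running maximum, the previously accumulated sums are rescaled before the new terms are added. A toy sketch checked against a direct softmax (not part of this commit; the kernel seeds `m` with `mask_value` rather than `-inf`, but the identity is the same):

```python
x = jnp.asarray([1.0, 3.0, 2.0, 0.5])
m, l, acc = -jnp.inf, 0.0, 0.0
for block in (x[:2], x[2:]):  # process two "kv blocks" in sequence
  m_new = jnp.maximum(m, block.max())
  correction = jnp.exp(m - m_new)  # rescales sums accumulated under the old max
  p = jnp.exp(block - m_new)
  acc = acc * correction + (p * block).sum()  # running weighted numerator
  l = l * correction + p.sum()  # running softmax denominator
  m = m_new
assert jnp.allclose(acc / l, (jax.nn.softmax(x) * x).sum())
```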
@@ -144,15 +140,9 @@ def inner_loop_body(i, carried_inner):
       # Calculates the attention computation (Q@K.T)@V with online softmax for
       # the current query and key/value blocks.
       def compute_attention_block(output, l, m):
-        output_i_slice = jax.lax.dynamic_slice_in_dim(
-            output, i * block_q, block_q, axis=-2
-        )
-        l_i_slice = jax.lax.dynamic_slice_in_dim(
-            l, i * block_q, block_q, axis=-1
-        )
-        m_i_slice = jax.lax.dynamic_slice_in_dim(
-            m, i * block_q, block_q, axis=-1
-        )
+        output_i_slice = jax.lax.dynamic_slice_in_dim(output, i * block_q, block_q, axis=-2)
+        l_i_slice = jax.lax.dynamic_slice_in_dim(l, i * block_q, block_q, axis=-1)
+        m_i_slice = jax.lax.dynamic_slice_in_dim(m, i * block_q, block_q, axis=-1)
         s_i_j = jnp.einsum(
             "bxhqc,bxkc->bxhqk",
             q_slice,
@@ -183,25 +173,19 @@ def compute_attention_block(output, l, m):
         l_i_new = m_i_difference * l_i_slice + m_i_j_difference * l_i_j

         divider = l_i_new[..., None]
-        numerator = l_i_slice[..., None] * m_i_difference[
+        numerator = l_i_slice[..., None] * m_i_difference[..., None] * output_i_slice + m_i_j_difference[
             ..., None
-        ] * output_i_slice + m_i_j_difference[..., None] * jnp.einsum(
+        ] * jnp.einsum(
             "bxhqk,bxkc->bxhqc",
             p_i_j,
             v_j_slice,
             preferred_element_type=data_type,
         )

         output_i_slice_new = numerator / divider
-        output = jax.lax.dynamic_update_index_in_dim(
-            output, output_i_slice_new, i * block_q, axis=-2
-        )
-        l = jax.lax.dynamic_update_index_in_dim(
-            l, l_i_new, i * block_q, axis=-1
-        )
-        m = jax.lax.dynamic_update_index_in_dim(
-            m, m_i_new, i * block_q, axis=-1
-        )
+        output = jax.lax.dynamic_update_index_in_dim(output, output_i_slice_new, i * block_q, axis=-2)
+        l = jax.lax.dynamic_update_index_in_dim(l, l_i_new, i * block_q, axis=-1)
+        m = jax.lax.dynamic_update_index_in_dim(m, m_i_new, i * block_q, axis=-1)
         return output, l, m

       def identity(output, l, m):
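The loop carries are functional values, so each q-block is read out with `dynamic_slice_in_dim`, rescaled, and written back at the same offset. A standalone sketch of that read-modify-write pattern on toy shapes (names here are illustrative only):

```python
block_q, i = 4, 1
out = jnp.zeros((2, 16, 8))  # (batch, q_seq_len, head_dim)
blk = jax.lax.dynamic_slice_in_dim(out, i * block_q, block_q, axis=-2)
blk = blk + 1.0  # stand-in for the real rescale-and-accumulate update
out = jax.lax.dynamic_update_index_in_dim(out, blk, i * block_q, axis=-2)
assert bool((out[:, 4:8] == 1.0).all()) and bool((out[:, :4] == 0.0).all())
```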
@@ -210,9 +194,7 @@ def identity(output, l, m):
         return output, l, m

       batch_size = mask_blocked.shape[0]
-      mask_i_j_slice = jax.lax.dynamic_slice(
-          mask_blocked, (0, i, j), (batch_size, 1, 1)
-      )
+      mask_i_j_slice = jax.lax.dynamic_slice(mask_blocked, (0, i, j), (batch_size, 1, 1))
       # The compute_attention_block should be executed if at least one element
       # in the slice is non-zero, meaning at least one batch requires work for
       # this block.
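The skip itself sits in lines elided from this diff; the shape of the pattern the comment describes is a `jax.lax.cond` whose two branches must return identically structured carries, which is why the no-op `identity` exists. A hedged sketch (the exact call in the file may differ):

```python
should_compute = jnp.any(mask_i_j_slice > 0)
output, l, m = jax.lax.cond(
    should_compute,
    compute_attention_block,  # some batch has unmasked work in this block
    identity,                 # fully masked block: pass the carries through
    output, l, m,
)
```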
@@ -227,15 +209,11 @@ def identity(output, l, m):

       return output, l, m

-    output, l, m = jax.lax.fori_loop(
-        0, num_q_blocks, inner_loop_body, (output, l, m), unroll=True
-    )
+    output, l, m = jax.lax.fori_loop(0, num_q_blocks, inner_loop_body, (output, l, m), unroll=True)

     return (output, l, m)

-  output, l, m = jax.lax.fori_loop(
-      0, num_kv_blocks, outer_loop_body, (output, l, m), unroll=True
-  )
+  output, l, m = jax.lax.fori_loop(0, num_kv_blocks, outer_loop_body, (output, l, m), unroll=True)

   # Reshape the output to drop the size one dimension at index 2,
   # which corresponds to `num_q_heads // num_kv_heads` when
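Both sweeps thread `(output, l, m)` through every block: the outer loop walks kv blocks, the inner loop walks q blocks. `unroll=True` asks XLA to inline the loop bodies, trading compile time and code size for the removal of loop overhead, a reasonable fit when block counts are small. A minimal sketch of the nesting (hypothetical body functions, toy counts):

```python
def inner(i, carry):
  return carry  # per-(q-block, kv-block) work would happen here

def outer(j, carry):
  return jax.lax.fori_loop(0, 4, inner, carry, unroll=True)  # e.g. num_q_blocks = 4

carry = jax.lax.fori_loop(0, 2, outer, (0.0,), unroll=True)  # e.g. num_kv_blocks = 2
```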
@@ -268,17 +246,11 @@ def mask_blocker(mask: jnp.ndarray, block_q: int, block_kv: int) -> jnp.ndarray:
   batch_size, q_seq_len, kv_seq_len = mask.shape

   if q_seq_len % block_q != 0:
-    raise ValueError(
-        f"q_seq_len {q_seq_len} must be divisible by block_q {block_q}"
-    )
+    raise ValueError(f"q_seq_len {q_seq_len} must be divisible by block_q {block_q}")
   if kv_seq_len % block_kv != 0:
-    raise ValueError(
-        f"kv_seq_len {kv_seq_len} must be divisible by block_kv {block_kv}"
-    )
+    raise ValueError(f"kv_seq_len {kv_seq_len} must be divisible by block_kv {block_kv}")
   q_blocks = q_seq_len // block_q
   kv_blocks = kv_seq_len // block_kv

-  blocked_mask = mask.reshape(
-      batch_size, q_blocks, block_q, kv_blocks, block_kv
-  )
+  blocked_mask = mask.reshape(batch_size, q_blocks, block_q, kv_blocks, block_kv)
   return jnp.count_nonzero(blocked_mask, axis=(2, 4)).astype(jnp.int32)
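A worked example of what `mask_blocker` returns (toy sizes, not from this commit): a causal 8x8 mask split into 4x4 blocks yields the per-block nonzero counts that let the caller skip fully masked (q, kv) block pairs.

```python
causal = jnp.tril(jnp.ones((1, 8, 8), dtype=jnp.bool_))  # batch of 1, causal mask
print(mask_blocker(causal, 4, 4)[0])
# [[10  0]   <- upper-right kv block is fully masked and can be skipped
#  [16 10]]
```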