@@ -112,15 +112,23 @@ def _unflatten_heads(tensor, heads):
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor
 
-
-def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+def _reshape_data_for_flash(tensor, heads):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
   blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
+  return tensor
+
+def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+  """
+  Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
+  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
+  blocks is divisible by the number of shards.
+  """
+  tensor = _reshape_data_for_flash(tensor, heads)
 
   # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
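For reference, a minimal sketch of the reshape half on its own: the transpose matches the context at the top of this hunk, while the reshape step and the helper name here are assumptions about what `_unflatten_heads` does with a fused head dimension.

import jax.numpy as jnp

def unflatten_heads_sketch(tensor, heads):
  # (batch, seq_len, heads * head_dim) -> (batch, seq_len, heads, head_dim)
  batch, seq_len, _ = tensor.shape
  tensor = jnp.reshape(tensor, (batch, seq_len, heads, -1))
  # -> (batch, heads, seq_len, head_dim), the layout the pallas flash kernels expect
  return jnp.transpose(tensor, (0, 2, 1, 3))

x = jnp.zeros((2, 10, 8 * 64))  # batch=2, seq_len=10, heads=8, head_dim=64
assert unflatten_heads_sketch(x, 8).shape == (2, 8, 10, 64)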
@@ -148,8 +156,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
 
   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    padded_tensor = jnp.pad(tensor, npad)
-    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
+    tensor = jnp.pad(tensor, npad)
 
   return tensor, kv_size, seq_len
 
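The padding arithmetic itself sits in unchanged lines between these two hunks, so it does not appear above. Below is a self-contained sketch of the behavior the docstring describes; the function name and the exact rounding expression are assumptions based on that docstring.

import jax.numpy as jnp

def pad_for_flash_sketch(tensor, flash_block_size, num_shards=1):
  # tensor: (batch, heads, seq_len, head_dim)
  seq_len, head_dim = tensor.shape[2], tensor.shape[3]
  # Pad seq_len to a multiple of flash_block_size * num_shards so the
  # number of blocks divides evenly across shards.
  multiple = flash_block_size * num_shards
  seq_len_pad = (multiple - seq_len % multiple) % multiple
  # Pad head_dim up to the 128 minimum the kernel expects.
  head_dim_pad = max(0, 128 - head_dim)
  npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
  # Return the true sizes so the caller can slice the padding back off.
  return jnp.pad(tensor, npad), head_dim, seq_len

x = jnp.zeros((1, 8, 1000, 64))
padded, kv_size, seq_len = pad_for_flash_sketch(x, flash_block_size=512)
assert padded.shape == (1, 8, 1024, 128) and (kv_size, seq_len) == (64, 1000)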
@@ -166,11 +173,13 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
+
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
-  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  # Cross-attention: kv seq_len differs from the query seq_len.
   if key.shape[1] != query.shape[1]:
+    assert key.shape[1] % 128 == 0, "cross-attention kv seq_len must be a multiple of 128"
     kv_max_block_size = key.shape[1]
+    # q_max_block_size = kv_max_block_size
   else:
     kv_max_block_size = q_max_block_size
   if flash_block_sizes:
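To make the block-size selection concrete, here is a hedged example of the `splash_attention_kernel.BlockSizes` built just below (its tail is visible in the next hunk's context). The field list mirrors how this file constructs it; the concrete sizes are hypothetical.

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

q_max_block_size = 1024               # bfloat16 path from the hunk above
kv_max_block_size = q_max_block_size  # self-attention: q and kv seq_len match
seq_len = 4096                        # hypothetical padded query/kv length

block_sizes = splash_attention_kernel.BlockSizes(
    block_q=min(q_max_block_size, seq_len),
    block_kv_compute=min(kv_max_block_size, seq_len),
    block_kv=min(kv_max_block_size, seq_len),
    block_q_dkv=min(q_max_block_size, seq_len),
    block_kv_dkv=min(kv_max_block_size, seq_len),
    block_kv_dkv_compute=min(kv_max_block_size, seq_len),
    block_q_dq=min(q_max_block_size, seq_len),
    block_kv_dq=min(kv_max_block_size, seq_len),
)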
@@ -186,43 +195,44 @@ def _tpu_flash_attention(
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
   )
-
-  num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+
+  query = _reshape_data_for_flash(query, heads)
+  key = _reshape_data_for_flash(key, heads)
+  value = _reshape_data_for_flash(value, heads)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
-  # To only attend to non-padded tokens.
-  segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, LENGTH))
-  segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
-  q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-  q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
-  kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
-  kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
-
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_q, segment_axis_names_kv),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
+  def wrap_flash_attention(query, key, value):
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
+    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
+    q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
+    kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+    kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
     # make_splash_mha is wrapped in shard_map, and seq and heads are already
     # sharded based on in_specs, therefore head_shards=1 and q_seq_shards=1.
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the size of the axis sharding over heads
        q_seq_shards=1,  # the size of the axis sharding over seq_len
        block_sizes=block_sizes,
    )
    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
-    return attention_output
+    return attention_output[:, :, :query_seq_len, :kv_size]
 
  devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
  # This warning might show up when doing model eval, for example when calculating model flops.
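The segment ids that moved inside `wrap_flash_attention` are what stop the kernel from attending to the padding added by `_pad_data_for_flash`: building them inside the shard_map body means each shard derives them from its own local padded shapes, which is why the segment-id specs disappeared from in_specs. A tiny standalone illustration of the construction, with hypothetical sizes:

import jax.numpy as jnp

batch, padded_q, padded_kv = 2, 8, 8
query_seq_len, key_seq_len = 6, 5  # true lengths before padding

q_segment_ids = jnp.broadcast_to(
    jnp.where(jnp.arange(padded_q) < query_seq_len, 1, 0), (batch, padded_q))
kv_segment_ids = jnp.broadcast_to(
    jnp.where(jnp.arange(padded_kv) < key_seq_len, 1, 0), (batch, padded_kv))

# The splash kernel only attends where q and kv segment ids match, so real
# tokens (id 1) never attend to padding (id 0).
print(q_segment_ids[0])   # [1 1 1 1 1 1 0 0]
print(kv_segment_ids[0])  # [1 1 1 1 1 0 0 0]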
@@ -232,8 +242,7 @@ def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
         "Warning: batch dimension should be shardable across the devices in the data and fsdp"
         f" axes, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids)
-  x = x[:, :, :query_seq_len, :kv_size]
+  x = wrap_flash_attention(query, key, value)
   x = _reshape_heads_to_head_dim(x)
 
   return x
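Since the unpadding slice also moved inside the wrapped function, callers now get the true shapes back directly. A toy check of that slice, with hypothetical shapes:

import jax.numpy as jnp

out = jnp.ones((1, 8, 1024, 128))  # padded kernel output
query_seq_len, kv_size = 1000, 64  # true lengths recorded before padding
trimmed = out[:, :, :query_seq_len, :kv_size]
assert trimmed.shape == (1, 8, 1000, 64)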