@@ -76,8 +76,8 @@ def _reshape_batch_dim_to_heads(tensor, heads):
   head_size = heads
   tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-  tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
-  return tensor
+  reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
+  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))


 def _reshape_heads_to_batch_dim(tensor, heads):
@@ -86,12 +86,12 @@ def _reshape_heads_to_batch_dim(tensor, heads):
     head_size = heads
     tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
     tensor = jnp.transpose(tensor, (0, 2, 1, 3))
-    tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+    reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
   else:
     batch_size, head_size, seq_len, head_dim = tensor.shape
-    tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)
+    reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)

-  return tensor
+  return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))


 def _reshape_heads_to_head_dim(tensor):
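Both reshape helpers now return their result through jax.lax.with_sharding_constraint instead of handing XLA an unconstrained intermediate, pinning the merged tensor to the ("data", "fsdp", "tensor") layout. Below is a minimal, self-contained sketch of the same pattern; the single-device mesh, the merge_heads name, and the shapes are illustrative only, and the sketch passes an explicit NamedSharding where the commit relies on the ambient mesh.

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative single-device mesh with the axis names used in the diff.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1, 1), ("data", "fsdp", "tensor"))


@functools.partial(jax.jit, static_argnames=["heads"])
def merge_heads(tensor, heads):
  batch_times_heads, seq_len, dim = tensor.shape
  tensor = tensor.reshape(batch_times_heads // heads, heads, seq_len, dim)
  tensor = jnp.transpose(tensor, (0, 2, 1, 3))
  merged = tensor.reshape(batch_times_heads // heads, seq_len, dim * heads)
  # Pin the (batch, seq, hidden) result to an explicit layout instead of letting
  # sharding propagation pick one for the reshaped intermediate.
  return jax.lax.with_sharding_constraint(merged, NamedSharding(mesh, PartitionSpec("data", "fsdp", "tensor")))


out = merge_heads(jnp.ones((8, 128, 64)), heads=4)  # -> shape (2, 128, 256)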
@@ -140,14 +140,15 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
   # 2. Ensure num_blocks is divisible by num_shards
   num_blocks = seq_len_padded_pre // flash_block_size
   if num_blocks % num_shards != 0:
-    num_blocks += (num_shards - (num_blocks % num_shards))
+    num_blocks += num_shards - (num_blocks % num_shards)

   final_padded_len = num_blocks * flash_block_size
   seq_len_pad = final_padded_len - seq_len

   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    tensor = jnp.pad(tensor, npad)
+    padded_tensor = jnp.pad(tensor, npad)
+    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "fsdp", "tensor"))

   return tensor, kv_size, seq_len

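The block-count adjustment above first pads the sequence to whole flash blocks and then rounds the block count up to a multiple of num_shards, so every shard receives the same number of blocks; the extra length is filled by the jnp.pad call, whose output now also carries an explicit sharding constraint. A small standalone check of that arithmetic follows; the helper name is illustrative, and it assumes seq_len_padded_pre is the sequence length rounded up to a whole block, which is computed before this hunk.

def final_padded_seq_len(seq_len, flash_block_size, num_shards):
  # Step 1: round up to a whole number of flash blocks (ceil division).
  num_blocks = -(-seq_len // flash_block_size)
  # Step 2: round the block count up to a multiple of num_shards so each
  # shard receives the same number of blocks.
  if num_blocks % num_shards != 0:
    num_blocks += num_shards - (num_blocks % num_shards)
  return num_blocks * flash_block_size


assert final_padded_seq_len(1000, flash_block_size=128, num_shards=4) == 1024  # 8 blocks, already divisible
assert final_padded_seq_len(1100, flash_block_size=128, num_shards=4) == 1536  # 9 blocks rounded up to 12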
@@ -189,40 +190,38 @@ def _tpu_flash_attention(
   flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
   named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-  shard_head_size = mesh.shape['tensor']
+
+  shard_head_size = mesh.shape["tensor"]

   @functools.partial(
       jax.jit,
-      static_argnames=[
-          "multi_head_mask",
-          "shard_head_size"
-      ],
+      static_argnames=["multi_head_mask", "shard_head_size"],
   )
   def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
-      mask=multi_head_mask,
-      head_shards=shard_head_size,  # the size of the axis sharding over heads
-      q_seq_shards=num_fsdp_shards,  # the size of the axis sharding over seq_len
-      block_sizes=block_sizes,
+        mask=multi_head_mask,
+        head_shards=shard_head_size,  # the size of the axis sharding over heads
+        q_seq_shards=num_fsdp_shards,  # the size of the axis sharding over seq_len
+        block_sizes=block_sizes,
     )
     return splash_kernel

   mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
   segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+
   @functools.partial(
-    shard_map.shard_map,
-    mesh=mesh,
-    in_specs=(
-      q_axis_names,
-      kv_axis_names,
-      kv_axis_names,
-      segment_axis_names_splash_kernel,
-    ),
-    out_specs=q_axis_names,
-    check_rep=False
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          q_axis_names,
+          kv_axis_names,
+          kv_axis_names,
+          segment_axis_names_splash_kernel,
+      ),
+      out_specs=q_axis_names,
+      check_rep=False,
   )
   def wrap_flash_attention(query, key, value, splash_kernel):
     attention_output = jax.vmap(splash_kernel)(query, key, value)
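In the new _tpu_flash_attention, the splash kernel is built inside a jitted helper whose mask and head-shard count are listed in static_argnames (so they are treated as compile-time constants), and the attention call itself is wrapped in shard_map so each device runs the kernel on its local shard, with jax.vmap mapping over the per-device batch. The sketch below mirrors only that shard_map-plus-vmap structure: it substitutes a plain scaled dot-product for the splash kernel, and the one-axis mesh and PartitionSpecs are illustrative rather than taken from the commit.

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import shard_map

# Illustrative one-device mesh; the commit's mesh has "data", "fsdp" and "tensor" axes.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1), ("data",))


def local_attention(q, k, v):
  # Stand-in for the splash kernel: scaled dot-product over one (heads, seq, head_dim) shard.
  weights = jax.nn.softmax(jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1]))
  return jnp.einsum("hqk,hkd->hqd", weights, v)


@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(PartitionSpec("data"), PartitionSpec("data"), PartitionSpec("data")),
    out_specs=PartitionSpec("data"),
    check_rep=False,
)
def wrap_attention(query, key, value):
  # vmap over the per-device batch dimension, as the commit does with the splash kernel.
  return jax.vmap(local_attention)(query, key, value)


q = k = v = jnp.ones((4, 2, 16, 8))  # (batch, heads, seq_len, head_dim)
out = jax.jit(wrap_attention)(q, k, v)  # -> shape (4, 2, 16, 8)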
@@ -386,7 +385,9 @@ def _apply_attention(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype)
+    return _tpu_flash_attention(
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
@@ -566,8 +567,8 @@ class AttentionOp(nn.Module):
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
-  axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
-  axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
+  axis_names_q: AxisNames = ((BATCH, HEAD, LENGTH, D_KV),)
+  axis_names_kv: AxisNames = ((BATCH, HEAD, KV_LENGTH, D_KV),)
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
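The change to the axis-name defaults is behavior-preserving: the old trailing comma already made each default a one-element tuple wrapping the axis tuple, and the reformatted lines simply spell that nesting out. A quick illustration of the Python semantics involved, with placeholder string values standing in for the logical axis constants:

BATCH, HEAD, LENGTH, D_KV = "batch", "heads", "length", "kv"

old_style = (BATCH, HEAD, LENGTH, D_KV),    # trailing comma wraps the 4-tuple in a 1-tuple
new_style = ((BATCH, HEAD, LENGTH, D_KV),)  # same value, nesting made explicit

assert old_style == new_style
assert len(old_style) == 1 and len(old_style[0]) == 4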
@@ -775,7 +776,6 @@ def __call__(
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None))

     attn_output = attn_output.astype(dtype=dtype)
