@@ -19,6 +19,7 @@
 import flax.linen as nn
 from flax import nnx
 import jax
+from jax.sharding import PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
@@ -139,21 +140,23 @@ def _tpu_flash_attention(
     heads: int,
     mesh: Mesh,
     flash_axis_names: AxisNames,
-    flash_block_sizes: BlockSizes) -> jax.Array:
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32) -> jax.Array:
   """TPU Flash Attention"""

+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(512, query.shape[2]),
-        block_kv_compute=min(512, key.shape[2]),
-        block_kv=min(512, key.shape[2]),
-        block_q_dkv=min(512, query.shape[2]),
-        block_kv_dkv=min(512, key.shape[2]),
-        block_kv_dkv_compute=min(512, query.shape[2]),
-        block_q_dq=min(512, query.shape[2]),
-        block_kv_dq=min(512, query.shape[2]),
+        block_q=min(max_block_size, query.shape[2]),
+        block_kv_compute=min(max_block_size, key.shape[2]),
+        block_kv=min(max_block_size, key.shape[2]),
+        block_q_dkv=min(max_block_size, query.shape[2]),
+        block_kv_dkv=min(max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
+        block_q_dq=min(max_block_size, query.shape[2]),
+        block_kv_dq=min(max_block_size, key.shape[2]),
     )

   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
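
Note on the hunk above: the block-size cap now keys off the activation dtype, presumably because bfloat16 halves the per-element footprint and lets the splash-attention kernel afford blocks twice as large, while still clamping to the actual sequence length. A minimal standalone sketch of that selection logic (the helper name and the shapes in the asserts are illustrative only; the diff inlines the expression directly in _tpu_flash_attention):

import jax.numpy as jnp

def pick_block_size(dtype, seq_len: int) -> int:
  # bfloat16 inputs allow a larger cap (1024); everything else keeps 512,
  # mirroring the max_block_size expression added above.
  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
  # Never exceed the sequence length actually present in the query/key tensors.
  return min(max_block_size, seq_len)

assert pick_block_size(jnp.bfloat16, 4096) == 1024
assert pick_block_size(jnp.float32, 4096) == 512
assert pick_block_size(jnp.bfloat16, 256) == 256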
@@ -340,7 +343,7 @@ def _apply_attention(
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention)
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes)
+    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
@@ -668,15 +671,15 @@ def __init__(
         rngs=rngs,
         epsilon=eps,
         dtype=dtype,
-        scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads",)),
+        scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
         param_dtype=weights_dtype
     )

     self.norm_k = nnx.RMSNorm(
         num_features=self.inner_dim,
         rngs=rngs,
         dtype=dtype,
-        scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads",)),
+        scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
         param_dtype=weights_dtype
     )
@@ -702,9 +705,12 @@ def __call__(
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
+    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec('data', 'fsdp', 'tensor'))
+    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec('data', 'fsdp', 'tensor'))
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
+
     query_proj = self.query(hidden_states)
     key_proj = self.key(encoder_hidden_states)
     value_proj = self.value(encoder_hidden_states)
@@ -717,8 +723,13 @@ def __call__(
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-
+    query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec('data', 'tensor', None, None))
+    key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec('data', 'tensor', None, None))
+    value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec('data', 'tensor', None, None))
+
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec('data', None, None))
+
     attn_output = attn_output.astype(dtype=dtype)

     hidden_states = self.proj_attn(attn_output)
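
The sharding constraints added in the last two hunks pin the attention activations to the same mesh axes the rest of the model uses: ('data', 'fsdp', 'tensor') over what are presumably the (batch, sequence, feature) dimensions of hidden_states, and ('data', 'tensor', None, None) over (batch, heads, sequence, head_dim) for the per-head projections. A minimal sketch of how such a constraint behaves under jit, using a hypothetical single-device 1x1x1 mesh with the same axis names (the diff itself relies on the surrounding training mesh and passes a bare PartitionSpec):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh: a single device, but with the ('data', 'fsdp', 'tensor')
# axis names that the constraints in the diff refer to.
mesh = Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1, 1), ("data", "fsdp", "tensor"))

@jax.jit
def constrain(hidden_states):
  # Ask the compiler to keep this intermediate sharded batch-on-'data',
  # sequence-on-'fsdp', features-on-'tensor', as in Attention.__call__.
  spec = PartitionSpec("data", "fsdp", "tensor")
  return jax.lax.with_sharding_constraint(hidden_states, NamedSharding(mesh, spec))

x = jnp.zeros((2, 128, 64))
print(constrain(x).sharding)  # NamedSharding over the 1x1x1 mesh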