@@ -52,6 +52,24 @@ class AttentionOp(nn.Module):
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
 
+  def setup(self):
+    if self.attention_kernel == "cudnn_flash_te":
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+      self.dpa_layer = DotProductAttention(
+          head_dim=self.dim_head,
+          num_attention_heads=self.heads,
+          num_gqa_groups=self.heads,
+          attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          # attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          # float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=self.scale,
+          transpose_batch_sequence=False,
+      )
+
   def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
     """Check attention inputs."""
 
@@ -64,16 +82,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
   def apply_attention(self, query: Array, key: Array, value: Array):
     """Routes to different attention kernels."""
     self.check_attention_inputs(query, key, value)
-    can_use_flash_attention = (
-        query.shape[1] >= self.flash_min_seq_length
-        and key.shape[1] >= self.flash_min_seq_length
-        and value.shape[1] >= self.flash_min_seq_length
-    )
+
+    if self.attention_kernel == "flash":
+      can_use_flash_attention = (
+          query.shape[1] >= self.flash_min_seq_length
+          and key.shape[1] >= self.flash_min_seq_length
+          and value.shape[1] >= self.flash_min_seq_length
+      )
+    else:
+      can_use_flash_attention = True
 
     if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention:
       return self.apply_attention_dot(query, key, value)
     elif self.attention_kernel == "flash":
       return self.tpu_flash_attention(query, key * self.scale, value)
+    elif self.attention_kernel == "cudnn_flash_te":
+      return self.cudnn_flash_attention(query, key, value)
     else:
       raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
 
@@ -132,6 +156,32 @@ def wrap_flash_attention(query, key, value):
 
     return x
 
+  def cudnn_flash_attention(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+  ) -> Array:
+    """CUDNN Flash Attention with Transformer Engine.
+    1. Stable API, supports GQA.
+    2. Supports head_dim up to 128; head_dim=256 support will be added soon.
+    """
+    # These imports are only meant to work in a GPU build.
+    # copied from tpu_flash_attention
+    query = self.reshape_data_for_cudnn_flash(query)
+    key = self.reshape_data_for_cudnn_flash(key)
+    value = self.reshape_data_for_cudnn_flash(value)
+
+    cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
+    axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)
+
+    query = nn.with_logical_constraint(query, axis_names)
+    key = nn.with_logical_constraint(key, axis_names)
+    value = nn.with_logical_constraint(value, axis_names)
+
+    out = self.dpa_layer(query, key, value, mask=None)
+    return self.reshape_data_from_cudnn_flash(out)
+
   def apply_attention_dot(self, query: Array, key: Array, value: Array):
     """Apply Attention."""
     if self.split_head_dim:
@@ -209,6 +259,16 @@ def reshape_batch_dim_to_heads(self, tensor):
     tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
     return tensor
 
+  def reshape_data_for_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h * d] to [b, s, h, d] (input format to cuDNN flash format)
+    batch, seq, heads_and_dim_head = tensor.shape
+    tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads)
+    return tensor
+
+  def reshape_data_from_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h, d] back to [b, s, h * d]
+    return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
+
   def reshape_data_for_flash(self, tensor):
     # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
     batch, seq, heads_and_dim_head = tensor.shape
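
For reference, here is a minimal standalone sketch (not part of the diff) of the layout round-trip that the new `reshape_data_for_cudnn_flash` / `reshape_data_from_cudnn_flash` helpers perform: packed `[b, s, h * d]` to the BSHD layout consumed with `qkv_layout="BSHD_BSHD_BSHD"`, and back. The batch, sequence, head, and head-dim sizes are made up for illustration.

```python
import jax.numpy as jnp

# Illustrative sizes only; the module reads heads/dim_head from its own attributes.
batch, seq, heads, head_dim = 2, 1024, 8, 64

x = jnp.zeros((batch, seq, heads * head_dim))  # packed [b, s, h * d] module layout

# Forward reshape: [b, s, h * d] -> [b, s, h, d] (BSHD), as in reshape_data_for_cudnn_flash.
x_bshd = x.reshape(batch, seq, heads, x.shape[-1] // heads)
assert x_bshd.shape == (2, 1024, 8, 64)

# Reverse reshape after attention: [b, s, h, d] -> [b, s, h * d], as in reshape_data_from_cudnn_flash.
x_packed = x_bshd.reshape(x_bshd.shape[0], x_bshd.shape[1], -1)
assert x_packed.shape == (2, 1024, 512)
```

Note this differs from the existing `reshape_data_for_flash`, which produces `[b, h, s, d]` for the TPU kernel; the cuDNN path keeps the sequence axis second.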