@@ -283,6 +283,7 @@ def wrap_flash_attention(query, key, value):
 
     block_kv = max(*block_kv_sizes)
     key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    print("Key seq len after padding:", key_seq_len)
     value, _, _ = _pad_data_for_flash(value, heads, block_kv)
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
@@ -293,8 +294,10 @@ def wrap_flash_attention(query, key, value):
     q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
 
     kv_padded_len = key.shape[2]
+    print("KV padded len:", kv_padded_len)
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+    print("KV segment ids:", kv_segment_ids)
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
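For context on what these prints expose: the key/value sequences are padded up to a multiple of the kv block size, and the segment ids built from `broadcasted_iota` mark which positions are real (1) versus padding (0) so the splash-attention kernel can ignore the pad. A minimal standalone sketch of that masking logic, where `pad_to_block` is a hypothetical stand-in for the repo's `_pad_data_for_flash` (not shown in this diff) and the shapes and block size are illustrative:

```python
import jax
import jax.numpy as jnp

def pad_to_block(x, block, axis=2):
    # Pad the sequence axis up to the next multiple of `block`;
    # return the padded array and the original sequence length.
    seq_len = x.shape[axis]
    padded_len = ((seq_len + block - 1) // block) * block
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, padded_len - seq_len)
    return jnp.pad(x, pad_width), seq_len

key = jnp.ones((1, 8, 1000, 64))           # (batch, heads, seq, head_dim)
key, key_seq_len = pad_to_block(key, 512)   # seq is padded from 1000 to 1024

# 1 for real positions, 0 for padding -- the same construction as the diff.
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (key.shape[2],), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
assert int(kv_segment_ids.sum()) == key_seq_len  # 1000 ones, 24 zeros
```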
@@ -1008,8 +1011,10 @@ def __call__(
         query_proj = self.query(hidden_states)
         with jax.named_scope("key_proj"):
             key_proj = self.key(encoder_hidden_states)
+            print("key_proj shape:", key_proj.shape)
         with jax.named_scope("value_proj"):
             value_proj = self.value(encoder_hidden_states)
+            print("value_proj shape:", value_proj.shape)
 
         if self.qk_norm:
             with self.conditional_named_scope("attn_q_norm"):
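One caveat about these debug prints (an observation about JAX tracing, not something stated in this commit): under `jax.jit`, a plain Python `print` fires once at trace time, so shape prints like `key_proj.shape` work fine, but printing array values such as `kv_segment_ids` shows an abstract tracer rather than the data. `jax.debug.print` emits the actual values at run time instead:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("traced value:", x)                # prints a tracer, once, at trace time
    jax.debug.print("runtime value: {}", x)  # prints the concrete array on every call
    return x * 2

f(jnp.arange(4))
```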