1515import contextlib
1616import functools
1717import math
18- from typing import Optional , Callable , Tuple
18+ from typing import Optional , Callable , Tuple , Dict
1919import flax .linen as nn
2020from flax import nnx
2121import jax
@@ -1132,6 +1132,7 @@ def __call__(
11321132 encoder_attention_mask : Optional [jax .Array ] = None ,
11331133 deterministic : bool = True ,
11341134 rngs : nnx .Rngs = None ,
1135+ cached_kv : Optional [Dict [str , Tuple [jax .Array , jax .Array ]]] = None ,
11351136 ) -> jax .Array :
11361137 axis_names = nn .logical_to_mesh_axes ((BATCH , LENGTH , HEAD ))
11371138 hidden_states = jax .lax .with_sharding_constraint (hidden_states , axis_names )
@@ -1146,16 +1147,22 @@ def __call__(
11461147 if not is_i2v_cross_attention :
11471148 with jax .named_scope ("query_proj" ):
11481149 query_proj = self .query (hidden_states )
1149- with jax .named_scope ("key_proj" ):
1150- key_proj = self .key (encoder_hidden_states )
1151- with jax .named_scope ("value_proj" ):
1152- value_proj = self .value (encoder_hidden_states )
1153-
1150+
11541151 if self .qk_norm :
11551152 with self .conditional_named_scope ("attn_q_norm" ):
11561153 query_proj = self .norm_q (query_proj )
1157- with self .conditional_named_scope ("attn_k_norm" ):
1158- key_proj = self .norm_k (key_proj )
1154+
1155+ if not is_self_attention and cached_kv is not None and "text" in cached_kv :
1156+ key_proj , value_proj = cached_kv ["text" ]
1157+ else :
1158+ with jax .named_scope ("key_proj" ):
1159+ key_proj = self .key (encoder_hidden_states )
1160+ with jax .named_scope ("value_proj" ):
1161+ value_proj = self .value (encoder_hidden_states )
1162+
1163+ if self .qk_norm :
1164+ with self .conditional_named_scope ("attn_k_norm" ):
1165+ key_proj = self .norm_k (key_proj )
11591166
11601167 if rotary_emb is not None :
11611168 with self .conditional_named_scope ("attn_rope" ):
@@ -1213,22 +1220,29 @@ def __call__(
12131220 query_proj_text = query_proj_raw
12141221
12151222 # Text K/V
1216- with self .conditional_named_scope ("proj_key" ):
1217- key_proj_text = self .key (encoder_hidden_states_text )
1218- if self .qk_norm :
1219- with self .conditional_named_scope ("attn_k_norm" ):
1220- key_proj_text = self .norm_k (key_proj_text )
1221- with self .conditional_named_scope ("proj_value" ):
1222- value_proj_text = self .value (encoder_hidden_states_text )
1223+ if cached_kv is not None and "text" in cached_kv :
1224+ key_proj_text , value_proj_text = cached_kv ["text" ]
1225+ else :
1226+ with self .conditional_named_scope ("proj_key" ):
1227+ key_proj_text = self .key (encoder_hidden_states_text )
1228+ if self .qk_norm :
1229+ with self .conditional_named_scope ("attn_k_norm" ):
1230+ key_proj_text = self .norm_k (key_proj_text )
1231+ with self .conditional_named_scope ("proj_value" ):
1232+ value_proj_text = self .value (encoder_hidden_states_text )
12231233
12241234 # Image K/V (only if image embeddings are present)
12251235 if encoder_hidden_states_img is not None :
1226- with self .conditional_named_scope ("add_proj_k" ):
1227- key_proj_img = self .add_k_proj (encoder_hidden_states_img )
1228- with self .conditional_named_scope ("norm_add_k" ):
1229- key_proj_img = self .norm_added_k (key_proj_img )
1230- with self .conditional_named_scope ("add_proj_v" ):
1231- value_proj_img = self .add_v_proj (encoder_hidden_states_img )
1236+ if cached_kv is not None and "image" in cached_kv :
1237+ key_proj_img , value_proj_img = cached_kv ["image" ]
1238+ else :
1239+ with self .conditional_named_scope ("add_proj_k" ):
1240+ key_proj_img = self .add_k_proj (encoder_hidden_states_img )
1241+ with self .conditional_named_scope ("norm_add_k" ):
1242+ key_proj_img = self .norm_added_k (key_proj_img )
1243+ with self .conditional_named_scope ("add_proj_v" ):
1244+ value_proj_img = self .add_v_proj (encoder_hidden_states_img )
1245+
12321246 query_proj_img = query_proj_raw
12331247 # Check norm_added_k too
12341248 # Checkpointing
@@ -1267,6 +1281,64 @@ def __call__(
12671281 hidden_states = self .drop_out (hidden_states , deterministic = deterministic , rngs = rngs )
12681282 return hidden_states
12691283
def compute_kv(
    self,
    encoder_hidden_states: jax.Array,
    encoder_attention_mask: Optional[jax.Array] = None,
) -> Dict[str, Tuple[jax.Array, jax.Array]]:
    """Precompute cross-attention key/value projections for reuse.

    The returned mapping matches the ``cached_kv`` parameter accepted by
    ``__call__``: ``{"text": (key, value)}`` for plain cross-attention, plus
    an ``"image"`` entry when this layer has the extra i2v image projections
    (``added_kv_proj_dim`` is set).

    Args:
        encoder_hidden_states: Conditioning sequence of shape
            (batch, seq, dim) — assumed layout; in i2v mode the image tokens
            are taken to precede the text tokens along the sequence axis.
        encoder_attention_mask: Optional mask. Its presence signals that the
            image tokens were padded up to a multiple of 128 for TPU flash
            attention; when ``None`` the image span is used unpadded.

    Returns:
        Dict mapping modality name ("text", optionally "image") to a
        ``(key_proj, value_proj)`` tuple.
    """
    is_i2v_cross_attention = self.added_kv_proj_dim is not None

    if not is_i2v_cross_attention:
        # Single (text) conditioning stream: one K/V pair, with the same
        # optional key normalization applied in __call__.
        with jax.named_scope("key_proj"):
            key_proj = self.key(encoder_hidden_states)
        with jax.named_scope("value_proj"):
            value_proj = self.value(encoder_hidden_states)

        if self.qk_norm:
            with self.conditional_named_scope("attn_k_norm"):
                key_proj = self.norm_k(key_proj)

        return {"text": (key_proj, value_proj)}

    # i2v: sequence is [image tokens | text tokens]. Image embeddings are
    # padded to a multiple of 128 for TPU flash attention.
    alignment = 128
    if self.image_seq_len is not None:
        image_seq_len_actual = self.image_seq_len
    else:
        # Fallback image-embedding length (e.g. 257 CLIP tokens) —
        # NOTE(review): confirm this default against the encoder in use.
        image_seq_len_actual = 257
    # Round up to the alignment boundary, e.g. 257 -> 384.
    padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment

    if encoder_attention_mask is None:
        # No mask means the image tokens were not padded.
        padded_img_len = image_seq_len_actual

    encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
    encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]

    # Text K/V.
    with self.conditional_named_scope("proj_key"):
        key_proj_text = self.key(encoder_hidden_states_text)
    if self.qk_norm:
        with self.conditional_named_scope("attn_k_norm"):
            key_proj_text = self.norm_k(key_proj_text)
    with self.conditional_named_scope("proj_value"):
        value_proj_text = self.value(encoder_hidden_states_text)

    # Image K/V. The previous `if encoder_hidden_states_img is not None`
    # guard was vacuous — a slice is never None — so the text-only fallback
    # branch was unreachable and has been removed.
    with self.conditional_named_scope("add_proj_k"):
        key_proj_img = self.add_k_proj(encoder_hidden_states_img)
    with self.conditional_named_scope("norm_add_k"):
        key_proj_img = self.norm_added_k(key_proj_img)
    with self.conditional_named_scope("add_proj_v"):
        value_proj_img = self.add_v_proj(encoder_hidden_states_img)

    return {
        "text": (key_proj_text, value_proj_text),
        "image": (key_proj_img, value_proj_img),
    }
12701342
12711343class FlaxFluxAttention (nn .Module ):
12721344 query_dim : int
0 commit comments