Commit b279c7d

Implement KV cache for cross-attention in Wan T2V and I2V pipelines
1 parent abb97c3 commit b279c7d

10 files changed

Lines changed: 255 additions & 46 deletions
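
Why a KV cache helps here: in cross-attention the keys and values are projections of the text encoder output (plus the CLIP image embeddings for I2V), and that conditioning is fixed for the entire denoising loop; only the query, computed from the evolving latent, changes from step to step. A minimal, self-contained sketch of that invariant (illustrative shapes and weights, not code from this commit):

    import jax
    import jax.numpy as jnp

    def cross_attention(q, k, v):
      # Plain scaled dot-product attention; shapes are illustrative, not WAN's.
      scores = q @ k.swapaxes(-1, -2) / jnp.sqrt(jnp.float32(q.shape[-1]))
      return jax.nn.softmax(scores, axis=-1) @ v

    text_emb = jax.random.normal(jax.random.PRNGKey(0), (1, 512, 64))  # conditioning: fixed for the whole loop
    w_q = w_k = w_v = jnp.eye(64)

    # K/V depend only on the conditioning, so project them once per prompt...
    k, v = text_emb @ w_k, text_emb @ w_v
    for t in range(4):  # denoising loop (truncated)
      latent_t = jax.random.normal(jax.random.PRNGKey(t + 1), (1, 256, 64))
      q = latent_t @ w_q  # ...only Q is recomputed at each step.
      out = cross_attention(q, k, v)

The changes below apply this idea inside the WAN attention and transformer modules: K/V are projected once per prompt via the new compute_kv / compute_kv_cache helpers and reused at every denoising step.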

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ flow_shift: 3.0
 # Skips the unconditional forward pass on ~35% of steps via residual compensation.
 # See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
 use_cfg_cache: False
+use_kv_cache: False
 use_magcache: False
 magcache_thresh: 0.12
 magcache_K: 2

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -282,6 +282,7 @@ flow_shift: 3.0
 
 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
 use_cfg_cache: False
+use_kv_cache: False
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 0 deletions
@@ -304,6 +304,7 @@ boundary_ratio: 0.875
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
 # when predicted output change (based on accumulated latent/timestep drift) is small
 use_sen_cache: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ flow_shift: 5.0
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 use_magcache: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 0 deletions
@@ -300,6 +300,7 @@ boundary_ratio: 0.875
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 

src/maxdiffusion/models/attention_flax.py

Lines changed: 93 additions & 21 deletions
@@ -15,7 +15,7 @@
 import contextlib
 import functools
 import math
-from typing import Optional, Callable, Tuple
+from typing import Optional, Callable, Tuple, Dict
 import flax.linen as nn
 from flax import nnx
 import jax
@@ -1130,6 +1130,7 @@ def __call__(
       encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
+      cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
   ) -> jax.Array:
     axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
@@ -1144,16 +1145,22 @@
     if not is_i2v_cross_attention:
       with jax.named_scope("query_proj"):
         query_proj = self.query(hidden_states)
-      with jax.named_scope("key_proj"):
-        key_proj = self.key(encoder_hidden_states)
-      with jax.named_scope("value_proj"):
-        value_proj = self.value(encoder_hidden_states)
-
+
       if self.qk_norm:
         with self.conditional_named_scope("attn_q_norm"):
           query_proj = self.norm_q(query_proj)
-        with self.conditional_named_scope("attn_k_norm"):
-          key_proj = self.norm_k(key_proj)
+
+      if not is_self_attention and cached_kv is not None and "text" in cached_kv:
+        key_proj, value_proj = cached_kv["text"]
+      else:
+        with jax.named_scope("key_proj"):
+          key_proj = self.key(encoder_hidden_states)
+        with jax.named_scope("value_proj"):
+          value_proj = self.value(encoder_hidden_states)
+
+        if self.qk_norm:
+          with self.conditional_named_scope("attn_k_norm"):
+            key_proj = self.norm_k(key_proj)
 
     if rotary_emb is not None:
       with self.conditional_named_scope("attn_rope"):
@@ -1211,22 +1218,29 @@ def __call__(
       query_proj_text = query_proj_raw
 
       # Text K/V
-      with self.conditional_named_scope("proj_key"):
-        key_proj_text = self.key(encoder_hidden_states_text)
-      if self.qk_norm:
-        with self.conditional_named_scope("attn_k_norm"):
-          key_proj_text = self.norm_k(key_proj_text)
-      with self.conditional_named_scope("proj_value"):
-        value_proj_text = self.value(encoder_hidden_states_text)
+      if cached_kv is not None and "text" in cached_kv:
+        key_proj_text, value_proj_text = cached_kv["text"]
+      else:
+        with self.conditional_named_scope("proj_key"):
+          key_proj_text = self.key(encoder_hidden_states_text)
+        if self.qk_norm:
+          with self.conditional_named_scope("attn_k_norm"):
+            key_proj_text = self.norm_k(key_proj_text)
+        with self.conditional_named_scope("proj_value"):
+          value_proj_text = self.value(encoder_hidden_states_text)
 
       # Image K/V (only if image embeddings are present)
       if encoder_hidden_states_img is not None:
-        with self.conditional_named_scope("add_proj_k"):
-          key_proj_img = self.add_k_proj(encoder_hidden_states_img)
-        with self.conditional_named_scope("norm_add_k"):
-          key_proj_img = self.norm_added_k(key_proj_img)
-        with self.conditional_named_scope("add_proj_v"):
-          value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+        if cached_kv is not None and "image" in cached_kv:
+          key_proj_img, value_proj_img = cached_kv["image"]
+        else:
+          with self.conditional_named_scope("add_proj_k"):
+            key_proj_img = self.add_k_proj(encoder_hidden_states_img)
+          with self.conditional_named_scope("norm_add_k"):
+            key_proj_img = self.norm_added_k(key_proj_img)
+          with self.conditional_named_scope("add_proj_v"):
+            value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+
         query_proj_img = query_proj_raw
       # Check norm_added_k too
       # Checkpointing
@@ -1264,6 +1278,64 @@ def __call__(
       hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
 
+  def compute_kv(
+      self,
+      encoder_hidden_states: jax.Array,
+      encoder_attention_mask: Optional[jax.Array] = None,
+  ) -> Dict[str, Tuple[jax.Array, jax.Array]]:
+    is_i2v_cross_attention = self.added_kv_proj_dim is not None
+
+    if not is_i2v_cross_attention:
+      with jax.named_scope("key_proj"):
+        key_proj = self.key(encoder_hidden_states)
+      with jax.named_scope("value_proj"):
+        value_proj = self.value(encoder_hidden_states)
+
+      if self.qk_norm:
+        with self.conditional_named_scope("attn_k_norm"):
+          key_proj = self.norm_k(key_proj)
+
+      return {"text": (key_proj, value_proj)}
+    else:
+      # Image embeddings are padded to multiples of 128 for TPU flash attention
+      alignment = 128
+      if self.image_seq_len is not None:
+        image_seq_len_actual = self.image_seq_len
+      else:
+        image_seq_len_actual = 257
+      padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment  # 257 -> 384
+
+      if encoder_attention_mask is None:
+        padded_img_len = image_seq_len_actual
+
+      encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
+      encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
+
+      # Text K/V
+      with self.conditional_named_scope("proj_key"):
+        key_proj_text = self.key(encoder_hidden_states_text)
+      if self.qk_norm:
+        with self.conditional_named_scope("attn_k_norm"):
+          key_proj_text = self.norm_k(key_proj_text)
+      with self.conditional_named_scope("proj_value"):
+        value_proj_text = self.value(encoder_hidden_states_text)
+
+      # Image K/V (only if image embeddings are present)
+      if encoder_hidden_states_img is not None:
+        with self.conditional_named_scope("add_proj_k"):
+          key_proj_img = self.add_k_proj(encoder_hidden_states_img)
+        with self.conditional_named_scope("norm_add_k"):
+          key_proj_img = self.norm_added_k(key_proj_img)
+        with self.conditional_named_scope("add_proj_v"):
+          value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+
+        return {
+            "text": (key_proj_text, value_proj_text),
+            "image": (key_proj_img, value_proj_img)
+        }
+      else:
+        return {"text": (key_proj_text, value_proj_text)}
+
 
 class FlaxFluxAttention(nn.Module):
   query_dim: int
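
The `cached_kv` argument accepted by `__call__` and the dict returned by `compute_kv` share the same layout: one entry per conditioning stream, each holding a (key, value) pair of already-projected (and, for keys, already-normalized) arrays. A shape sketch, with illustrative batch/sequence/feature sizes that are not taken from this commit:

    from typing import Dict, Tuple
    import jax.numpy as jnp

    # T2V cross-attention caches only the text stream.
    cached_kv_t2v: Dict[str, Tuple[jnp.ndarray, jnp.ndarray]] = {
        "text": (jnp.zeros((1, 512, 1536)), jnp.zeros((1, 512, 1536))),
    }

    # I2V additionally caches the image stream produced by add_k_proj / add_v_proj;
    # per the comment in compute_kv, the 257 image tokens are padded to 384 for
    # TPU flash attention when an encoder_attention_mask is present.
    cached_kv_i2v: Dict[str, Tuple[jnp.ndarray, jnp.ndarray]] = {
        "text": (jnp.zeros((1, 512, 1536)), jnp.zeros((1, 512, 1536))),
        "image": (jnp.zeros((1, 384, 1536)), jnp.zeros((1, 384, 1536))),
    }

When an entry is present, `__call__` skips the corresponding key/value projection (and its `norm_k`); the query path is unchanged.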

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 79 additions & 5 deletions
@@ -374,6 +374,7 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
       encoder_attention_mask: Optional[jax.Array] = None,
+      cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
   ):
     with self.conditional_named_scope("transformer_block"):
       shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
@@ -413,6 +414,7 @@ def __call__(
           deterministic=deterministic,
           rngs=rngs,
           encoder_attention_mask=encoder_attention_mask,
+          cached_kv=cached_kv,
       )
     with self.conditional_named_scope("cross_attn_residual"):
       hidden_states = hidden_states + attn_output
@@ -431,6 +433,13 @@
       )
     return hidden_states
 
+  def compute_kv(
+      self,
+      encoder_hidden_states: jax.Array,
+      encoder_attention_mask: Optional[jax.Array] = None,
+  ) -> Dict[str, Tuple[jax.Array, jax.Array]]:
+    return self.attn2.compute_kv(encoder_hidden_states, encoder_attention_mask)
+
 
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
 
@@ -583,6 +592,53 @@ def conditional_named_scope(self, name: str):
     """Return a JAX named scope if enabled, otherwise a null context."""
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
 
+  def compute_kv_cache(
+      self,
+      encoder_hidden_states: jax.Array,
+      encoder_hidden_states_image: Optional[jax.Array] = None,
+      timestep: Optional[jax.Array] = None,
+  ) -> Dict[str, Tuple[jax.Array, jax.Array]]:
+    if timestep is None:
+      batch_size = encoder_hidden_states.shape[0]
+      timestep = jnp.zeros((batch_size,), dtype=jnp.int32)
+
+    with self.conditional_named_scope("condition_embedder"):
+      (
+          temb,
+          timestep_proj,
+          encoder_hidden_states,
+          encoder_hidden_states_image,
+          encoder_attention_mask,
+      ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
+
+    if encoder_hidden_states_image is not None:
+      encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
+      if encoder_attention_mask is not None:
+        text_mask = jnp.ones(
+            (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
+            dtype=jnp.int32,
+        )
+        encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
+
+    if self.scan_layers:
+      @nnx.vmap(in_axes=(0, None, None), out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
+      def _compute_kv(block, enc_states, enc_mask):
+        return block.compute_kv(enc_states, enc_mask)
+
+      kv_cache = _compute_kv(self.blocks, encoder_hidden_states, encoder_attention_mask)
+    else:
+      kv_cache_list = []
+      for block in self.blocks:
+        kv_cache_list.append(block.compute_kv(encoder_hidden_states, encoder_attention_mask))
+      keys = kv_cache_list[0].keys()
+      kv_cache = {}
+      for k in keys:
+        k_list = [d[k][0] for d in kv_cache_list]
+        v_list = [d[k][1] for d in kv_cache_list]
+        kv_cache[k] = (jnp.stack(k_list, axis=0), jnp.stack(v_list, axis=0))
+
+    return kv_cache
+
   @jax.named_scope("WanModel")
   def __call__(
     self,
@@ -597,6 +653,7 @@ def __call__(
     skip_blocks: Optional[jax.Array] = None,
     cached_residual: Optional[jax.Array] = None,
     return_residual: bool = False,
+    kv_cache: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
  ) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]:
     hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
     batch_size, _, num_frames, height, width = hidden_states.shape
@@ -634,8 +691,14 @@ def __call__(
     def _run_all_blocks(h):
       if self.scan_layers:
 
-        def scan_fn(carry, block):
+        def scan_fn(carry, block_input):
           hidden_states_carry, rngs_carry = carry
+          if kv_cache is not None:
+            block, layer_kv_cache = block_input
+          else:
+            block = block_input
+            layer_kv_cache = None
+
           hidden_states = block(
               hidden_states_carry,
               encoder_hidden_states,
@@ -644,6 +707,7 @@ def scan_fn(carry, block):
               deterministic,
               rngs_carry,
               encoder_attention_mask,
+              cached_kv=layer_kv_cache,
           )
           new_carry = (hidden_states, rngs_carry)
           return new_carry, None
@@ -652,19 +716,28 @@ def scan_fn(carry, block):
             scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
         )
         initial_carry = (h, rngs)
+
+        if kv_cache is not None:
+          scan_input = (self.blocks, kv_cache)
+        else:
+          scan_input = self.blocks
+
         final_carry, _ = nnx.scan(
             rematted_block_forward,
             length=self.num_layers,
             in_axes=(nnx.Carry, 0),
             out_axes=(nnx.Carry, 0),
-        )(initial_carry, self.blocks)
+        )(initial_carry, scan_input)
 
         h_out, _ = final_carry
       else:
         h_out = h
-        for block in self.blocks:
+        for i, block in enumerate(self.blocks):
+          layer_kv_cache = None
+          if kv_cache is not None:
+            layer_kv_cache = jax.tree_map(lambda x: x[i], kv_cache)
 
-          def layer_forward(hidden_states):
+          def layer_forward(hidden_states, l_kv):
             return block(
                 hidden_states,
                 encoder_hidden_states,
@@ -673,6 +746,7 @@ def layer_forward(hidden_states):
                 deterministic,
                 rngs,
                 encoder_attention_mask=encoder_attention_mask,
+                cached_kv=l_kv,
             )
 
           rematted_layer_forward = self.gradient_checkpoint.apply(
@@ -681,7 +755,7 @@ def layer_forward(hidden_states):
              self.names_which_can_be_offloaded,
              prevent_cse=not self.scan_layers,
          )
-          h_out = rematted_layer_forward(h_out)
+          h_out = rematted_layer_forward(h_out, layer_kv_cache)
       return h_out
 
     hidden_states_before_blocks = hidden_states
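
How the two new entry points are intended to fit together: the pipeline wiring lives in the remaining files of this 10-file commit and is not reproduced above, so the loop below is a hedged sketch. `wan_model`, `config`, `prompt_embeds`, `image_embeds`, `latents`, `timesteps`, and `scheduler_step` are assumed names; only `compute_kv_cache`, `kv_cache`, and `use_kv_cache` come from the diff itself.

    # Hedged sketch, not the pipeline code from this commit.
    kv_cache = None
    if config.use_kv_cache:
      # One pass over the conditioning before the denoising loop. With scan_layers the
      # per-block K/V come back stacked on a leading layer axis via nnx.vmap; otherwise
      # compute_kv_cache stacks them with jnp.stack, and __call__ slices layer i out
      # with jax.tree_map(lambda x: x[i], kv_cache).
      kv_cache = wan_model.compute_kv_cache(prompt_embeds, image_embeds)

    for t in timesteps:
      noise_pred = wan_model(
          latents,
          timestep=t,
          encoder_hidden_states=prompt_embeds,
          encoder_hidden_states_image=image_embeds,
          kv_cache=kv_cache,  # cross-attention K/V projections are reused instead of recomputed
      )
      latents = scheduler_step(latents, noise_pred, t)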
