
Commit 673e2c5

Implement KV cache and lossless optimizations for Wan pipelines

1 parent 384d211

12 files changed
Lines changed: 395 additions & 59 deletions

src/maxdiffusion/configs/base_wan_14b.yml
Lines changed: 1 addition & 0 deletions

@@ -332,6 +332,7 @@ flow_shift: 3.0
 # Skips the unconditional forward pass on ~35% of steps via residual compensation.
 # See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
 use_cfg_cache: False
+use_kv_cache: False
 use_magcache: False
 magcache_thresh: 0.12
 magcache_K: 2
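
The flag defaults to False, so nothing changes until a pipeline opts in. A minimal sketch of how a caller might consume it, assuming a pipeline object that exposes its cross-attention modules; maybe_precompute_kv and cross_attention_layers are illustrative names, not part of this commit:

# Hypothetical consumer of use_kv_cache; only compute_kv() is from this commit.
def maybe_precompute_kv(pipeline, config, prompt_embeds):
  if not config.use_kv_cache:
    return None
  # Prompt embeddings are fixed across the denoising loop, so each
  # cross-attention layer's K/V projections can be computed once up front.
  return [attn.compute_kv(prompt_embeds) for attn in pipeline.cross_attention_layers]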

src/maxdiffusion/configs/base_wan_1_3b.yml
Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

(The removed and added copyright lines render identically; this appears to be a whitespace-only cleanup.)

@@ -286,6 +286,7 @@ flow_shift: 3.0
 
 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
 use_cfg_cache: False
+use_kv_cache: False
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0

src/maxdiffusion/configs/base_wan_27b.yml
Lines changed: 1 addition & 0 deletions

@@ -307,6 +307,7 @@ boundary_ratio: 0.875
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
 # when predicted output change (based on accumulated latent/timestep drift) is small
 use_sen_cache: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml
Lines changed: 1 addition & 0 deletions

@@ -291,6 +291,7 @@ flow_shift: 5.0
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 use_magcache: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml
Lines changed: 1 addition & 0 deletions

@@ -303,6 +303,7 @@ boundary_ratio: 0.875
 
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 
src/maxdiffusion/models/attention_flax.py
Lines changed: 93 additions & 21 deletions

@@ -15,7 +15,7 @@
 import contextlib
 import functools
 import math
-from typing import Optional, Callable, Tuple
+from typing import Optional, Callable, Tuple, Dict
 import flax.linen as nn
 from flax import nnx
 import jax
@@ -1132,6 +1132,7 @@ def __call__(
       encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
+      cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
   ) -> jax.Array:
     axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
@@ -1146,16 +1147,22 @@ def __call__(
     if not is_i2v_cross_attention:
       with jax.named_scope("query_proj"):
         query_proj = self.query(hidden_states)
-      with jax.named_scope("key_proj"):
-        key_proj = self.key(encoder_hidden_states)
-      with jax.named_scope("value_proj"):
-        value_proj = self.value(encoder_hidden_states)
-
+
       if self.qk_norm:
         with self.conditional_named_scope("attn_q_norm"):
           query_proj = self.norm_q(query_proj)
-        with self.conditional_named_scope("attn_k_norm"):
-          key_proj = self.norm_k(key_proj)
+
+      if not is_self_attention and cached_kv is not None and "text" in cached_kv:
+        key_proj, value_proj = cached_kv["text"]
+      else:
+        with jax.named_scope("key_proj"):
+          key_proj = self.key(encoder_hidden_states)
+        with jax.named_scope("value_proj"):
+          value_proj = self.value(encoder_hidden_states)
+
+        if self.qk_norm:
+          with self.conditional_named_scope("attn_k_norm"):
+            key_proj = self.norm_k(key_proj)
 
     if rotary_emb is not None:
       with self.conditional_named_scope("attn_rope"):
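
The guard on this path is worth noting: cached_kv is only consulted when not is_self_attention, since self-attention K/V depend on the evolving latents and must be recomputed every step, while cross-attention K/V depend only on the fixed prompt embeddings. Because the cache stores the post-norm_k tensors, a cache hit reproduces the else-branch exactly. A sketch of the contract, with an abbreviated call signature:

# Sketch: hit and miss paths should agree to float tolerance, since the cached
# (K, V) pair is exactly what the else-branch would recompute.
kv = attn.compute_kv(prompt_embeds)                   # {"text": (key_proj, value_proj)}
out_miss = attn(latents, prompt_embeds)               # projects K/V in-line
out_hit = attn(latents, prompt_embeds, cached_kv=kv)  # skips two matmuls + norm_k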
@@ -1213,22 +1220,29 @@ def __call__(
       query_proj_text = query_proj_raw
 
       # Text K/V
-      with self.conditional_named_scope("proj_key"):
-        key_proj_text = self.key(encoder_hidden_states_text)
-      if self.qk_norm:
-        with self.conditional_named_scope("attn_k_norm"):
-          key_proj_text = self.norm_k(key_proj_text)
-      with self.conditional_named_scope("proj_value"):
-        value_proj_text = self.value(encoder_hidden_states_text)
+      if cached_kv is not None and "text" in cached_kv:
+        key_proj_text, value_proj_text = cached_kv["text"]
+      else:
+        with self.conditional_named_scope("proj_key"):
+          key_proj_text = self.key(encoder_hidden_states_text)
+        if self.qk_norm:
+          with self.conditional_named_scope("attn_k_norm"):
+            key_proj_text = self.norm_k(key_proj_text)
+        with self.conditional_named_scope("proj_value"):
+          value_proj_text = self.value(encoder_hidden_states_text)
 
       # Image K/V (only if image embeddings are present)
       if encoder_hidden_states_img is not None:
-        with self.conditional_named_scope("add_proj_k"):
-          key_proj_img = self.add_k_proj(encoder_hidden_states_img)
-        with self.conditional_named_scope("norm_add_k"):
-          key_proj_img = self.norm_added_k(key_proj_img)
-        with self.conditional_named_scope("add_proj_v"):
-          value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+        if cached_kv is not None and "image" in cached_kv:
+          key_proj_img, value_proj_img = cached_kv["image"]
+        else:
+          with self.conditional_named_scope("add_proj_k"):
+            key_proj_img = self.add_k_proj(encoder_hidden_states_img)
+          with self.conditional_named_scope("norm_add_k"):
+            key_proj_img = self.norm_added_k(key_proj_img)
+          with self.conditional_named_scope("add_proj_v"):
+            value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+
         query_proj_img = query_proj_raw
         # Check norm_added_k too
       # Checkpointing
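
In the I2V branch the cache carries two entries, and the image pair is stored after norm_added_k, so reuse is again exact rather than approximate. The layout consumed above:

# Cache layout for the I2V cross-attention path; the "text"/"image" keys are
# from this commit, and the variable names mirror the code above.
cached_kv = {
    "text": (key_proj_text, value_proj_text),   # prompt K/V, post norm_k
    "image": (key_proj_img, value_proj_img),    # image K/V, post norm_added_k
}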
@@ -1267,6 +1281,64 @@ def __call__(
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
 
+  def compute_kv(
+      self,
+      encoder_hidden_states: jax.Array,
+      encoder_attention_mask: Optional[jax.Array] = None,
+  ) -> Dict[str, Tuple[jax.Array, jax.Array]]:
+    is_i2v_cross_attention = self.added_kv_proj_dim is not None
+
+    if not is_i2v_cross_attention:
+      with jax.named_scope("key_proj"):
+        key_proj = self.key(encoder_hidden_states)
+      with jax.named_scope("value_proj"):
+        value_proj = self.value(encoder_hidden_states)
+
+      if self.qk_norm:
+        with self.conditional_named_scope("attn_k_norm"):
+          key_proj = self.norm_k(key_proj)
+
+      return {"text": (key_proj, value_proj)}
+    else:
+      # Image embeddings are padded to multiples of 128 for TPU flash attention
+      alignment = 128
+      if self.image_seq_len is not None:
+        image_seq_len_actual = self.image_seq_len
+      else:
+        image_seq_len_actual = 257
+      padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment  # 257 -> 384
+
+      if encoder_attention_mask is None:
+        padded_img_len = image_seq_len_actual
+
+      encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
+      encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
+
+      # Text K/V
+      with self.conditional_named_scope("proj_key"):
+        key_proj_text = self.key(encoder_hidden_states_text)
+      if self.qk_norm:
+        with self.conditional_named_scope("attn_k_norm"):
+          key_proj_text = self.norm_k(key_proj_text)
+      with self.conditional_named_scope("proj_value"):
+        value_proj_text = self.value(encoder_hidden_states_text)
+
+      # Image K/V (only if image embeddings are present)
+      if encoder_hidden_states_img is not None:
+        with self.conditional_named_scope("add_proj_k"):
+          key_proj_img = self.add_k_proj(encoder_hidden_states_img)
+        with self.conditional_named_scope("norm_add_k"):
+          key_proj_img = self.norm_added_k(key_proj_img)
+        with self.conditional_named_scope("add_proj_v"):
+          value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+
+        return {
+            "text": (key_proj_text, value_proj_text),
+            "image": (key_proj_img, value_proj_img)
+        }
+      else:
+        return {"text": (key_proj_text, value_proj_text)}
+
 
 class FlaxFluxAttention(nn.Module):
   query_dim: int
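
Putting the pieces together: a denoising loop can call compute_kv once per generation and thread the result through every step, which is what makes this a lossless optimization, unlike the approximate MagCache/SenCache paths gated nearby in the configs. A sketch under assumed names (transformer, scheduler, blocks, and attn2 are illustrative; only compute_kv and the cached_kv argument come from this commit):

# Hypothetical wiring in a Wan inference loop.
caches = [
    block.attn2.compute_kv(encoder_hidden_states, encoder_attention_mask)
    for block in transformer.blocks  # one {"text": ..., "image": ...} dict per layer
]
for t in timesteps:
  # Q and self-attention K/V depend on the current latents and are recomputed;
  # cross-attention K/V are served from the per-layer caches built above.
  noise_pred = transformer(latents, t, encoder_hidden_states, cached_kv=caches)
  latents = scheduler.step(noise_pred, t, latents)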
