
Commit 1abc00c

wrap up attention.
1 parent b9b2465 commit 1abc00c

2 files changed

Lines changed: 106 additions & 62 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 26 additions & 62 deletions
@@ -101,16 +101,28 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size):
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   """
   tensor = _unflatten_heads(tensor, heads)
+
+  # pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
+  head_dim_pad = 0
   if kv_size < 128:
-    npad = ((0, 0), (0, 0), (0, 0), (0, 128 - kv_size))
-    tensor = jnp.pad(tensor, npad)
+    head_dim_pad = 128 - kv_size
+
+  # pad seq_len to a multiple of flash_block_size if needed.
   seq_len = tensor.shape[2]
+  # remainder
   rem = seq_len % flash_block_size
+  seq_len_pad = 0
   if rem != 0:
+    # multiplier
     mul = seq_len // flash_block_size
-    npad = ((0, 0), (0, 0), (0, (mul + 1)*flash_block_size - seq_len), (0, 0))
+    # pad to the closest multiplier of flash_block_size
+    seq_len_pad = (mul + 1) * flash_block_size - seq_len
+
+  if kv_size < 128 or rem != 0:
+    npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
     tensor = jnp.pad(tensor, npad)
+
   return tensor, kv_size, seq_len

 def _tpu_flash_attention(
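
A standalone sketch of the padding behavior introduced above (illustrative only; the helper name and shapes below are not from this commit): head_dim is padded up to 128 and seq_len up to the next multiple of flash_block_size, with both pads applied in a single jnp.pad call.

import jax.numpy as jnp

# Sketch of the padding logic, not the module's code.
def pad_for_flash(tensor, flash_block_size):
  kv_size = tensor.shape[-1]                      # original head_dim
  head_dim_pad = max(0, 128 - kv_size)
  seq_len = tensor.shape[2]                       # original seq_len
  rem = seq_len % flash_block_size
  seq_len_pad = (flash_block_size - rem) if rem else 0
  if head_dim_pad or seq_len_pad:
    npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
    tensor = jnp.pad(tensor, npad)
  return tensor, kv_size, seq_len

x = jnp.ones((1, 8, 1000, 64))                    # (batch, heads, seq_len, head_dim)
padded, kv_size, seq_len = pad_for_flash(x, flash_block_size=256)
print(padded.shape)                               # (1, 8, 1024, 128)
print(kv_size, seq_len)                           # 64 1000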
@@ -140,15 +152,7 @@ def _tpu_flash_attention(
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
   key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
-  # query_seq_len = query.shape[2]
-  # query_rem = query_seq_len % block_sizes.block_q
-  # if query_rem != 0:
-  # query_mul = query_seq_len // block_sizes.block_q
-  # npad = ((0, 0), (0, 0), (0, (query_mul + 1)*block_sizes.block_q - query.shape[2]), (0, 0))
-  # query = jnp.pad(query, npad)
-  # key = jnp.pad(key, npad)
-  # value = jnp.pad(value, npad)
-  # breakpoint()
+
   axis_names = nn.logical_to_mesh_axes(flash_axis_names)

   @functools.partial(
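
_reshape_data_for_flash returns the original kv_size and sequence length alongside the padded tensor. The call site that strips the padding from the flash output again is not part of this diff, but presumably looks something like the sketch below (the array name and output layout are assumptions):

# Hypothetical post-attention cleanup, assuming the output is laid out as
# (batch, heads, seq_len, head_dim): crop back to the caller's original sizes.
attention_output = attention_output[:, :, :query_seq_len, :kv_size]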
@@ -456,7 +460,7 @@ def __init__(
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
-      raise NotImplementedError("Wan 2.1 has not been tested with cudnn_flash_te")
+      raise NotImplementedError(f"{self} has not been tested with {attention_kernel}")

     self.mesh = mesh
     self.scale = scale
@@ -574,34 +578,13 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
   ):
-    # TODO - Params from pytorch implementation
-    # to set for the creation of this.
-    # bias is True
-    # upcast_attention - False
-    # upcast_softmax - False
-    # cross_attention_norm - None
-    # cross_attention_norm_num_groups - 32
-    # qk_norm - rms_norm_across_heads
-    # added_kv_proj_dim
-    # norm_num_groups: Optional[int] = None,
-    # spatial_norm_dim: Optional[int] = None,
-    # out_bias: bool = True,
-    # scale_qk: bool = True,
-    # only_cross_attention - False
-    # eps - 1e-06
-    # rescale_output_factor: float = 1.0,
-    # residual_connection: bool = False,
-    # _from_deprecated_attn_block: bool = False,
-    # processor: Optional["AttnProcessor"] = WanAttnProcessor2_0
-    # out_dim: int = None,
-    # out_context_dim: int = None,
-    # context_pre_only=None,
-    # pre_only=False,
-    # elementwise_affine: bool = True,
-    # is_causal: bool = False,
+
+    if attention_kernel == "cudnn_flash_te" or attention_kernel == "dot_product":
+      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+
     self.dim_head = dim_head
     self.heads = heads
     self.inner_dim = dim_head * heads
@@ -717,7 +700,8 @@ def __call__(
       encoder_hidden_states: jax.Array,
       rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
-    batch_size = hidden_states.shape[0]
+    dtype = hidden_states.dtype
+    # batch_size = hidden_states.shape[0]
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
     query_proj = self.query(hidden_states)
@@ -735,35 +719,15 @@ def __call__(
     key_proj = _unflatten_heads(key_proj, self.heads)
     if rotary_emb is not None:
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-    #breakpoint()
     query_proj = _reshape_heads_to_head_dim(query_proj)
     key_proj = _reshape_heads_to_head_dim(key_proj)
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    breakpoint()
-
+    attn_output = attn_output.astype(dtype=dtype)

+    hidden_states = self.proj_attn(hidden_states)
+    return hidden_states


-  def setup(self):
-    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
-      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
-    inner_dim = self.dim_head * self.heads
-    scale = self.dim_head**-0.5
-
-    self.attention_op = NNXAttentionOp(
-        mesh=self.mesh,
-        attention_kernel=self.attention_kernel,
-        scale=scale,
-        heads=self.heads,
-        dim_head=self.dim_head,
-        flash_min_seq_length=self.flash_min_seq_length,
-        use_memory_efficient_attention=self.use_memory_efficient_attention,
-        split_head_dim=self.split_head_dim,
-        flash_block_sizes=self.flash_block_sizes,
-        dtype=self.dtype,
-        float32_qk_product=False,
-    )
-
 class FlaxFluxAttention(nn.Module):
   query_dim: int
   heads: int = 8
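
The __call__ change above records the input dtype and casts the attention output back to it before the output projection. A minimal illustration of the pattern (the helper and its attention_fn argument are hypothetical, not maxdiffusion APIs):

import jax.numpy as jnp

# Illustrative only: remember the caller's dtype, let the attention op return
# whatever precision it computes in (often float32), then cast back so the
# output projection sees the original dtype.
def run_attention_preserving_dtype(hidden_states, attention_fn):
  dtype = hidden_states.dtype
  attn_output = attention_fn(hidden_states)
  return attn_output.astype(dtype=dtype)

x = jnp.ones((1, 16, 64), dtype=jnp.bfloat16)
y = run_attention_preserving_dtype(x, lambda h: h.astype(jnp.float32) * 2.0)
print(y.dtype)  # bfloat16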

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 80 additions & 0 deletions
@@ -14,15 +14,25 @@
 limitations under the License.
 """

+import os
 import jax
 import jax.numpy as jnp
 import unittest
 from absl.testing import absltest
 from flax import nnx
+from jax.sharding import Mesh

+from .. import pyconfig
+from ..max_utils import (
+    create_device_mesh,
+    get_flash_block_sizes
+)
 from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed, WanTimeTextImageEmbedding
 from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection
 from ..models.normalization_flax import FP32LayerNorm
+from ..models.attention_flax import FlaxWanAttention
+
+THIS_DIR = os.path.dirname(os.path.abspath(__file__))

 class WanTransformerTest(unittest.TestCase):
   def setUp(self):
@@ -108,6 +118,76 @@ def test_wan_time_text_embedding(self):
     assert temb.shape == (batch_size, dim)
     assert timestep_proj.shape == (batch_size, time_proj_dim)
     assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+
+  def test_wan_attention(self):
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(
+        attention_head_dim=128,
+        patch_size=[1, 2, 2],
+        max_seq_len=1024
+    )
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    devices_array = create_device_mesh(config)
+
+    flash_block_sizes = get_flash_block_sizes(config)
+
+    mesh = Mesh(devices_array, config.mesh_axes)
+    batch_size = 1
+    query_dim = 5120
+    attention = FlaxWanAttention(
+        rngs=rngs,
+        query_dim=query_dim,
+        heads=40,
+        dim_head=128,
+        attention_kernel="flash",
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes,
+    )
+
+    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+
+    dummy_output = attention(
+        hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+    )
+    assert dummy_output.shape == dummy_hidden_states_shape
+
+    # dot product
+    try:
+      attention = FlaxWanAttention(
+          rngs=rngs,
+          query_dim=query_dim,
+          heads=40,
+          dim_head=128,
+          attention_kernel="dot_product",
+          split_head_dim=True,
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
+    except NotImplementedError as e:
+      pass
+
+

 if __name__ == "__main__":
   absltest.main()
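
One note on the dot_product check in the new test: the try/except NotImplementedError/pass form also passes when no error is raised at all. A tighter variant, sketched here with the same constructor arguments inside test_wan_attention, would be:

# Sketch: fails the test if FlaxWanAttention does NOT raise NotImplementedError.
with self.assertRaises(NotImplementedError):
  FlaxWanAttention(
      rngs=rngs,
      query_dim=query_dim,
      heads=40,
      dim_head=128,
      attention_kernel="dot_product",
      split_head_dim=True,
      mesh=mesh,
      flash_block_sizes=flash_block_sizes,
  )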
