Commit 1ae2616

Adds a VACE transformer block
1 parent 4836217 commit 1ae2616

2 files changed

Lines changed: 391 additions & 0 deletions


Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
"""Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Tuple

from flax import nnx
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import PartitionSpec

from .... import common_types
from ...attention_flax import FlaxWanAttention
from ...normalization_flax import FP32LayerNorm
from .transformer_wan import WanFeedForward

BlockSizes = common_types.BlockSizes

class WanVACETransformerBlock(nnx.Module):
  """Attention block for VACE.

  Processes the conditioning signals and produces latent codes that can be
  summed into the main branch of WAN.

  Based on
  https://github.com/huggingface/diffusers/blob/be3c2a0667493022f17d756ca3dba631d28dfb40/src/diffusers/models/transformers/transformer_wan_vace.py#L41C7-L41C30
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      *,
      dim: int,
      ffn_dim: int,
      num_heads: int,
      qk_norm: str = "rms_norm_across_heads",
      cross_attn_norm: bool = False,
      eps: float = 1e-6,
      flash_min_seq_length: int = 4096,
      flash_block_sizes: BlockSizes | None = None,
      mesh: jax.sharding.Mesh | None = None,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision | None = None,
      attention: str = "dot_product",
      dropout: float = 0.0,
      apply_input_projection: bool = False,
      apply_output_projection: bool = False,
  ):
    """Sets up the model.

    Args:
      rngs: Random number generator.
      dim: Internal dimension of the block.
      ffn_dim: Dimension of the feed-forward network.
      num_heads: Number of attention heads.
      qk_norm: Type of normalization applied to the query and key vectors
        (e.g. "rms_norm_across_heads").
      cross_attn_norm: Whether to apply layer normalization before
        cross-attention (only True is supported).
      eps: Epsilon value for normalization.
      flash_min_seq_length: Minimum sequence length for flash attention.
      flash_block_sizes: Block sizes for flash attention.
      mesh: Sharding topology.
      dtype: Data type for the computation.
      weights_dtype: Data type for parameter initializers (see param_dtype in
        nnx.Linear).
      precision: Precision for the computation.
      attention: Type of attention kernel to use.
      dropout: Dropout rate.
      apply_input_projection: Whether to apply a linear projection to the
        inputs.
      apply_output_projection: Whether to apply an output projection before
        returning the result.
    """

    self.apply_input_projection = apply_input_projection
    self.apply_output_projection = apply_output_projection

    # 1. Input projection
    self.proj_in = nnx.data([None])
    if apply_input_projection:
      self.proj_in = nnx.Linear(
          rngs=rngs,
          in_features=dim,
          out_features=dim,
          dtype=dtype,
          param_dtype=weights_dtype,
          precision=precision,
          kernel_init=nnx.with_partitioning(
              nnx.initializers.xavier_uniform(), ("embed", None)
          ),
      )

    # 2. Self-attention
    self.norm1 = FP32LayerNorm(
        rngs=rngs, dim=dim, eps=eps, elementwise_affine=False
    )
    self.attn1 = FlaxWanAttention(
        rngs=rngs,
        query_dim=dim,
        heads=num_heads,
        dim_head=dim // num_heads,
        qk_norm=qk_norm,
        eps=eps,
        flash_min_seq_length=flash_min_seq_length,
        flash_block_sizes=flash_block_sizes,
        mesh=mesh,
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
        attention_kernel=attention,
        dropout=dropout,
        residual_checkpoint_name="self_attn",
    )

    # 3. Cross-attention
    self.attn2 = FlaxWanAttention(
        rngs=rngs,
        query_dim=dim,
        heads=num_heads,
        dim_head=dim // num_heads,
        qk_norm=qk_norm,
        eps=eps,
        flash_min_seq_length=flash_min_seq_length,
        flash_block_sizes=flash_block_sizes,
        mesh=mesh,
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
        attention_kernel=attention,
        dropout=dropout,
        residual_checkpoint_name="cross_attn",
    )
    assert cross_attn_norm is True, "cross_attn_norm must be True"
    self.norm2 = FP32LayerNorm(
        rngs=rngs, dim=dim, eps=eps, elementwise_affine=True
    )

    # 4. Feed-forward
    self.ffn = WanFeedForward(
        rngs=rngs,
        dim=dim,
        inner_dim=ffn_dim,
        activation_fn="gelu-approximate",
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
        dropout=dropout,
    )

    self.norm3 = FP32LayerNorm(
        rngs=rngs, dim=dim, eps=eps, elementwise_affine=False
    )

    # 5. Output projection
    self.proj_out = nnx.data([None])
    if apply_output_projection:
      self.proj_out = nnx.Linear(
          rngs=rngs,
          in_features=dim,
          out_features=dim,
          dtype=dtype,
          param_dtype=weights_dtype,
          precision=precision,
          kernel_init=nnx.with_partitioning(
              nnx.initializers.xavier_uniform(), ("embed", None)
          ),
      )

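    # Learned adaLN modulation table: six rows of (shift, scale, gate) offsets
    # that are added to `temb` in __call__ and split into per-sub-block terms.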
    key = rngs.params()
    self.adaln_scale_shift_table = nnx.Param(
        jax.random.normal(key, (1, 6, dim)) / dim**0.5,
    )

  def __call__(
      self,
      *,
      hidden_states: jax.Array,
      encoder_hidden_states: jax.Array,
      control_hidden_states: jax.Array,
      temb: jax.Array,
      rotary_emb: jax.Array,
      deterministic: bool = True,
      rngs: nnx.Rngs | None = None,
  ) -> Tuple[jax.Array, jax.Array]:
    if self.apply_input_projection:
      control_hidden_states = self.proj_in(control_hidden_states)
      control_hidden_states = control_hidden_states + hidden_states

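    # `temb` has shape (batch, 6, dim); adding the learned table and splitting
    # along axis 1 yields shift/scale/gate terms for the self-attention and
    # feed-forward modulation.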
    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
        jnp.split(
            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
        )
    )

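    # Shard the control stream as (batch="data", seq="fsdp", embed="tensor")
    # and the text stream with its embedding dimension replicated.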
    control_hidden_states = jax.lax.with_sharding_constraint(
        control_hidden_states,
        PartitionSpec("data", "fsdp", "tensor"),
    )
    control_hidden_states = checkpoint_name(
        control_hidden_states, "control_hidden_states"
    )
    encoder_hidden_states = jax.lax.with_sharding_constraint(
        encoder_hidden_states,
        PartitionSpec("data", "fsdp", None),
    )

    # 1. Self-attention
    with jax.named_scope("attn1"):
      norm_hidden_states = (
          self.norm1(control_hidden_states.astype(jnp.float32))
          * (1 + scale_msa)
          + shift_msa
      ).astype(control_hidden_states.dtype)
      attn_output = self.attn1(
          hidden_states=norm_hidden_states,
          encoder_hidden_states=norm_hidden_states,
          rotary_emb=rotary_emb,
          deterministic=deterministic,
          rngs=rngs,
      )
      control_hidden_states = (
          control_hidden_states.astype(jnp.float32) + attn_output * gate_msa
      ).astype(control_hidden_states.dtype)

    # 2. Cross-attention
    with jax.named_scope("attn2"):
      norm_hidden_states = self.norm2(
          control_hidden_states.astype(jnp.float32)
      ).astype(control_hidden_states.dtype)
      attn_output = self.attn2(
          hidden_states=norm_hidden_states,
          encoder_hidden_states=encoder_hidden_states,
          deterministic=deterministic,
          rngs=rngs,
      )
      control_hidden_states = control_hidden_states + attn_output

    # 3. Feed-forward
    with jax.named_scope("ffn"):
      norm_hidden_states = (
          self.norm3(control_hidden_states.astype(jnp.float32))
          * (1 + c_scale_msa)
          + c_shift_msa
      ).astype(control_hidden_states.dtype)
      ff_output = self.ffn(
          norm_hidden_states, deterministic=deterministic, rngs=rngs
      )
      control_hidden_states = (
          control_hidden_states.astype(jnp.float32)
          + ff_output.astype(jnp.float32) * c_gate_msa
      ).astype(control_hidden_states.dtype)

    conditioning_states = None
    if self.apply_output_projection:
      conditioning_states = self.proj_out(control_hidden_states)

    return conditioning_states, control_hidden_states
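
Note on usage: a minimal sketch of how the two return values are meant to be consumed, following the class docstring and the referenced diffusers model; the names `apply_vace_conditioning`, `vace_scale`, and `hidden_states` below are hypothetical and not part of this commit. `conditioning_states` is scaled and summed into the main WAN branch, while `control_hidden_states` is carried forward to the next VACE block.

def apply_vace_conditioning(hidden_states, conditioning_states, vace_scale=1.0):
  # Sum the VACE latent codes (optionally scaled) into the main-branch hidden states.
  return hidden_states + conditioning_states * vace_scale
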
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import jax
import jax.numpy as jnp
import unittest
from absl.testing import absltest
from flax import nnx
from jax.sharding import Mesh

from .. import pyconfig
from ..max_utils import (create_device_mesh, get_flash_block_sizes)
from ..models.wan.transformers.transformer_wan import (
    WanRotaryPosEmbed,
)
from ..models.wan.transformers.transformer_wan_vace import (
    WanVACETransformerBlock,
)
import qwix
import flax

flax.config.update("flax_always_shard_variable", False)
RealQtRule = qwix.QtRule


IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

THIS_DIR = os.path.dirname(os.path.abspath(__file__))


class WanVaceTransformerTest(unittest.TestCase):
  def test_wan_vace_block_returns_the_correct_shape(self):
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
    pyconfig.initialize(
        [
            None,
            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
        ],
        unittest=True,
    )
    config = pyconfig.config

    devices_array = create_device_mesh(config)

    flash_block_sizes = get_flash_block_sizes(config)

    mesh = Mesh(devices_array, config.mesh_axes)

    dim = 5120
    ffn_dim = 13824
    num_heads = 40
    qk_norm = "rms_norm_across_heads"
    cross_attn_norm = True
    eps = 1e-6

    batch_size = 1
    channels = 16
    frames = 21
    height = 90
    width = 160
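    # 75600 = frames * (height // 2) * (width // 2) = 21 * 45 * 80, the
    # flattened token count after patchifying with a (1, 2, 2) patch size.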
    hidden_dim = 75600

    # for rotary pos embed.
    hidden_states_shape = (batch_size, frames, height, width, channels)
    dummy_hidden_states = jnp.ones(hidden_states_shape)

    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
    assert dummy_rotary_emb.shape == (batch_size, 1, hidden_dim, 64)

    # for transformer block
    dummy_hidden_states = jnp.ones((batch_size, hidden_dim, dim))

    dummy_control_hidden_states = jnp.ones((batch_size, hidden_dim, dim))

    dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))

    dummy_temb = jnp.ones((batch_size, 6, dim))

    wan_vace_block = WanVACETransformerBlock(
        rngs=rngs,
        dim=dim,
        ffn_dim=ffn_dim,
        num_heads=num_heads,
        qk_norm=qk_norm,
        cross_attn_norm=cross_attn_norm,
        eps=eps,
        attention="flash",
        mesh=mesh,
        flash_block_sizes=flash_block_sizes,
        apply_input_projection=True,
        apply_output_projection=True,
    )
    with mesh:
      conditioning_states, control_hidden_states = wan_vace_block(
          hidden_states=dummy_hidden_states,
          encoder_hidden_states=dummy_encoder_hidden_states,
          control_hidden_states=dummy_control_hidden_states,
          temb=dummy_temb,
          rotary_emb=dummy_rotary_emb,
      )
    assert conditioning_states.shape == dummy_hidden_states.shape
    assert control_hidden_states.shape == dummy_hidden_states.shape


if __name__ == "__main__":
  absltest.main()
