Commit d449d1f

add WanRotaryPosEmbed
1 parent b31b4ad commit d449d1f

4 files changed

Lines changed: 241 additions & 5 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py
src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 15 additions & 5 deletions
@@ -102,7 +102,13 @@ def __call__(self, timesteps):
 
 
 def get_1d_rotary_pos_embed(
-    dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32
+    dim: int,
+    pos: Union[jnp.array, int],
+    theta: float = 10000.0,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    freqs_dtype=jnp.float32,
+    use_real: bool = True,
 ):
   """
   Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -115,10 +121,14 @@ def get_1d_rotary_pos_embed(
   theta = theta * ntk_factor
   freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor
   freqs = jnp.outer(pos, freqs)
-  freqs_cos = jnp.cos(freqs)
-  freqs_sin = jnp.sin(freqs)
-  out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1)
-
+  if use_real:
+    # Flux: real-valued [cos, -sin, sin, cos] rotation entries.
+    freqs_cos = jnp.cos(freqs)
+    freqs_sin = jnp.sin(freqs)
+    out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1)
+  else:
+    # Wan 2.1: unit complex exponentials e^{i*freqs} (the torch.polar(ones, freqs) equivalent).
+    out = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs))
   return out
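For context, here is a minimal standalone sketch, not part of the commit and using toy sizes, of what the two branches produce: the real-valued Flux-style stack versus the Wan-style unit complex exponentials.

import jax
import jax.numpy as jnp

# Toy illustration of the two output forms of get_1d_rotary_pos_embed.
dim, pos = 8, jnp.arange(4)
freqs = 1.0 / (10000.0 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
freqs = jnp.outer(pos, freqs)  # (4, dim // 2)

# use_real=True (Flux): real cos/sin rotation entries, with one extra axis.
real_out = jnp.stack([jnp.cos(freqs), -jnp.sin(freqs), jnp.sin(freqs), jnp.cos(freqs)], axis=-1)
assert real_out.shape == (4, dim // 2, 4)

# use_real=False (Wan 2.1): unit complex exponentials e^{i*freqs}.
complex_out = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs))
assert complex_out.shape == (4, dim // 2)
assert jnp.iscomplexobj(complex_out)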
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+from typing import Tuple, Optional
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from .... import common_types, max_logging
+from ...modeling_flax_utils import FlaxModelMixin
+from ....configuration_utils import ConfigMixin
+from ...embeddings_flax import get_1d_rotary_pos_embed
+
+BlockSizes = common_types.BlockSizes
+
+
+class WanRotaryPosEmbed(nnx.Module):
+  def __init__(
+      self,
+      attention_head_dim: int,
+      patch_size: Tuple[int, int, int],
+      max_seq_len: int,
+      theta: float = 10000.0,
+  ):
+    self.attention_head_dim = attention_head_dim
+    self.patch_size = patch_size
+    self.max_seq_len = max_seq_len
+
+    # Split the head dim across (time, height, width): height and width each
+    # get 2 * (dim // 6) channels and time takes the remainder.
+    h_dim = w_dim = 2 * (attention_head_dim // 6)
+    t_dim = attention_head_dim - h_dim - w_dim
+
+    freqs = []
+    for dim in [t_dim, h_dim, w_dim]:
+      freq = get_1d_rotary_pos_embed(
+          dim,
+          self.max_seq_len,
+          theta,
+          freqs_dtype=jnp.float64,
+          use_real=False,
+      )
+      freqs.append(freq)
+    self.freqs = jnp.concatenate(freqs, axis=1)
+
+  def __call__(self, hidden_states: jax.Array) -> jax.Array:
+    _, num_frames, height, width, _ = hidden_states.shape
+    p_t, p_h, p_w = self.patch_size
+    ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+    sizes = [
+        self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+        self.attention_head_dim // 6,
+        self.attention_head_dim // 6,
+    ]
+    cumulative_sizes = jnp.cumsum(jnp.array(sizes))
+    split_indices = cumulative_sizes[:-1]
+    freqs_split = jnp.split(self.freqs, split_indices, axis=1)
+
+    # Broadcast each axis's frequencies over the full patch grid (ppf, pph, ppw).
+    freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
+    freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
+
+    freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
+    freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))
+
+    freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1)
+    freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1]))
+
+    freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1)
+    freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1))
+    return freqs_final
+
+
+class WanTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin):
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      patch_size: Tuple[int, int, int] = (1, 2, 2),
+      num_attention_heads: int = 40,
+      attention_head_dim: int = 128,
+      in_channels: int = 16,
+      out_channels: int = 16,
+      text_dim: int = 4096,
+      freq_dim: int = 256,
+      ffn_dim: int = 13824,
+      num_layers: int = 40,
+      cross_attn_norm: bool = True,
+      qk_norm: Optional[str] = "rms_norm_across_heads",
+      eps: float = 1e-6,
+      image_dim: Optional[int] = None,
+      added_kv_proj_dim: Optional[int] = None,
+      rope_max_seq_len: int = 1024,
+      pos_embed_seq_len: Optional[int] = None,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
+  ):
+    inner_dim = num_attention_heads * attention_head_dim
+    out_channels = out_channels or in_channels
+
+    # 1. Patch & position embedding
+    self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+
+
+class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
+
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      model_type="t2v",
+      patch_size=(1, 2, 2),
+      text_len=512,
+      in_dim=16,
+      dim=2048,
+      ffn_dim=8192,
+      freq_dim=256,
+      text_dim=4096,
+      out_dim=16,
+      num_heads=16,
+      num_layers=32,
+      window_size=(-1, -1),
+      qk_norm=True,
+      cross_attn_norm=True,
+      eps=1e-6,
+      flash_min_seq_length: int = 4096,
+      flash_block_sizes: BlockSizes = None,
+      mesh: jax.sharding.Mesh = None,
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: jax.lax.Precision = None,
+      attention: str = "dot_product",
+  ):
+    # Patchify via a 3D convolution whose kernel and stride both equal patch_size.
+    self.patch_embedding = nnx.Conv(
+        in_dim,
+        dim,
+        kernel_size=patch_size,
+        strides=patch_size,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)),
+        rngs=rngs,
+    )
+
+  def __call__(self, x):
+    x = self.patch_embedding(x)
+    return x
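The commit only precomputes the frequency table; nothing here consumes it yet. Below is a hedged sketch of how complex frequencies of this shape are typically applied to attention queries and keys; the apply_rotary_emb helper is illustrative and assumed, not part of this commit. With the test's 21x90x160 latent and patch size (1, 2, 2), the table passed in would be the (1, 1, 75600, 64) output of WanRotaryPosEmbed, since 21 * 45 * 80 = 75600 tokens and 128 // 2 = 64 complex frequencies.

import jax
import jax.numpy as jnp

def apply_rotary_emb(x: jax.Array, freqs: jax.Array) -> jax.Array:
  # x: (batch, heads, seq, head_dim) real values.
  # freqs: (1, 1, seq, head_dim // 2) complex, as returned by WanRotaryPosEmbed.
  x_complex = jax.lax.complex(x[..., 0::2], x[..., 1::2])  # pair up (even, odd) channels
  x_rotated = x_complex * freqs                            # rotate each pair by e^{i*theta}
  # Interleave real/imag back into the original channel layout.
  out = jnp.stack([jnp.real(x_rotated), jnp.imag(x_rotated)], axis=-1)
  return out.reshape(x.shape).astype(x.dtype)

# Toy shapes: (batch=1, heads=2, seq=6, head_dim=8) with identity rotations.
q = jnp.ones((1, 2, 6, 8))
freqs = jnp.exp(1j * jnp.zeros((1, 1, 6, 4)))  # stand-in for the precomputed table
assert apply_rotary_emb(q, freqs).shape == q.shape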
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+import jax
+import jax.numpy as jnp
+import unittest
+from absl.testing import absltest
+from flax import nnx
+
+from ..models.wan.transformers.transformer_wan import WanRotaryPosEmbed
+
+
+class WanTransformerTest(unittest.TestCase):
+  def setUp(self):
+    self.dummy_data = {}
+
+  def test_rotary_pos_embed(self):
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=(1, 2, 2), max_seq_len=1024)
+    dummy_output = wan_rot_embed(dummy_hidden_states)
+    # 21 * (90 // 2) * (160 // 2) = 75600 tokens and 128 // 2 = 64 complex
+    # frequencies; matches the PyTorch reference torch.Size([1, 1, 75600, 64]).
+    assert dummy_output.shape == (1, 1, 75600, 64)
+
+
+if __name__ == "__main__":
+  absltest.main()

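One JAX-specific caveat for this test: WanRotaryPosEmbed requests jnp.float64 frequencies, but JAX disables 64-bit types by default, so the table is silently computed in float32 (and the complex frequencies in complex64) unless x64 mode is enabled before any arrays are created. A quick illustration:

import jax
import jax.numpy as jnp

# Under the default config, float64 requests fall back to float32 (JAX emits
# a warning), so the rotary table ends up complex64 rather than complex128.
x = jnp.arange(4, dtype=jnp.float64)
print(x.dtype)  # float32 by default

# Opt in to true 64-bit computation before creating any arrays:
# jax.config.update("jax_enable_x64", True)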