1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
from typing import Tuple, Optional
import jax
import jax.numpy as jnp
from flax import nnx
from .... import common_types, max_logging
from ...modeling_flax_utils import FlaxModelMixin
from ....configuration_utils import ConfigMixin
from ...embeddings_flax import get_1d_rotary_pos_embed

BlockSizes = common_types.BlockSizes


class WanRotaryPosEmbed(nnx.Module):
  def __init__(
      self,
      attention_head_dim: int,
      patch_size: Tuple[int, int, int],
      max_seq_len: int,
      theta: float = 10000.0,
  ):
    self.attention_head_dim = attention_head_dim
    self.patch_size = patch_size
    self.max_seq_len = max_seq_len

    # Split the per-head dimension across the temporal, height, and width axes.
    h_dim = w_dim = 2 * (attention_head_dim // 6)
    t_dim = attention_head_dim - h_dim - w_dim

    # Precompute the complex rotary frequencies for each axis and concatenate
    # them along the feature dimension.
    freqs = []
    for dim in [t_dim, h_dim, w_dim]:
      freq = get_1d_rotary_pos_embed(
          dim,
          self.max_seq_len,
          theta,
          freqs_dtype=jnp.float64,
          use_real=False,
      )
      freqs.append(freq)
    self.freqs = jnp.concatenate(freqs, axis=1)

  def __call__(self, hidden_states: jax.Array) -> jax.Array:
    _, num_frames, height, width, _ = hidden_states.shape
    p_t, p_h, p_w = self.patch_size
    # Number of patches along the frame, height, and width axes.
    ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

    # Split the precomputed frequencies back into their temporal, height,
    # and width components.
    sizes = [
        self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
        self.attention_head_dim // 6,
        self.attention_head_dim // 6,
    ]
    cumulative_sizes = jnp.cumsum(jnp.array(sizes))
    split_indices = cumulative_sizes[:-1]
    freqs_split = jnp.split(self.freqs, split_indices, axis=1)

    # Broadcast each component over the full (frames, height, width) patch grid.
    freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
    freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))

    freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
    freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))

    freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1)
    freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1]))

    # Flatten the grid into a single sequence axis: (1, 1, ppf * pph * ppw, dim // 2).
    freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1)
    freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1))
    return freqs_final


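# A minimal usage sketch (the shapes below are illustrative assumptions, not
# values taken from this file): the rotary embedding consumes pre-patchified
# video latents in channels-last layout and returns complex frequencies with
# one row per output patch.
#
#   rope = WanRotaryPosEmbed(attention_head_dim=128, patch_size=(1, 2, 2), max_seq_len=1024)
#   latents = jnp.zeros((1, 21, 60, 104, 16))  # (batch, frames, height, width, channels)
#   freqs = rope(latents)                      # (1, 1, 21 * 30 * 52, 64), complex valued
#
# Note that freqs_dtype=jnp.float64 is honored only when x64 support is enabled
# in JAX (jax.config.update("jax_enable_x64", True)); otherwise the frequencies
# are downcast to float32.

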
class WanTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin):
  def __init__(
      self,
      rngs: nnx.Rngs,
      patch_size: Tuple[int, int, int] = (1, 2, 2),
      num_attention_heads: int = 40,
      attention_head_dim: int = 128,
      in_channels: int = 16,
      out_channels: int = 16,
      text_dim: int = 4096,
      freq_dim: int = 256,
      ffn_dim: int = 13824,
      num_layers: int = 40,
      cross_attn_norm: bool = True,
      qk_norm: Optional[str] = "rms_norm_across_heads",
      eps: float = 1e-6,
      image_dim: Optional[int] = None,
      added_kv_proj_dim: Optional[int] = None,
      rope_max_seq_len: int = 1024,
      pos_embed_seq_len: Optional[int] = None,
      flash_min_seq_length: int = 4096,
      flash_block_sizes: Optional[BlockSizes] = None,
      mesh: Optional[jax.sharding.Mesh] = None,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: Optional[jax.lax.Precision] = None,
      attention: str = "dot_product",
  ):
    inner_dim = num_attention_heads * attention_head_dim
    out_channels = out_channels or in_channels

    # 1. Patch & position embedding
    self.rope = WanRotaryPosEmbed(
        attention_head_dim,
        patch_size,
        rope_max_seq_len,
    )


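# With the defaults above, inner_dim = 40 * 128 = 5120, and WanRotaryPosEmbed
# splits each 128-dim head into 44 temporal + 42 height + 42 width rotary
# dimensions (h_dim = w_dim = 2 * (128 // 6) = 42, t_dim = 128 - 42 - 42 = 44).

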
class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):

  def __init__(
      self,
      rngs: nnx.Rngs,
      model_type="t2v",
      patch_size=(1, 2, 2),
      text_len=512,
      in_dim=16,
      dim=2048,
      ffn_dim=8192,
      freq_dim=256,
      text_dim=4096,
      out_dim=16,
      num_heads=16,
      num_layers=32,
      window_size=(-1, -1),
      qk_norm=True,
      cross_attn_norm=True,
      eps=1e-6,
      flash_min_seq_length: int = 4096,
      flash_block_sizes: Optional[BlockSizes] = None,
      mesh: Optional[jax.sharding.Mesh] = None,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: Optional[jax.lax.Precision] = None,
      attention: str = "dot_product",
  ):
    # Patchify the input video latents with a strided convolution that maps
    # each (p_t, p_h, p_w) patch of in_dim channels to a dim-dimensional token.
    self.patch_embedding = nnx.Conv(
        in_dim,
        dim,
        kernel_size=patch_size,
        strides=patch_size,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)),
        rngs=rngs,
    )

  def __call__(self, x):
    x = self.patch_embedding(x)
    return x
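

# A minimal usage sketch (shapes here are illustrative assumptions): with the
# default patch_size of (1, 2, 2), the patch embedding is a strided 3D
# convolution over channels-last video latents, so an input of shape
# (batch, frames, height, width, in_dim) maps to
# (batch, frames, height // 2, width // 2, dim).
#
#   model = WanModel(rngs=nnx.Rngs(0))
#   latents = jnp.zeros((1, 21, 60, 104, 16))
#   tokens = model(latents)  # (1, 21, 30, 52, 2048)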