1414 limitations under the License.
1515"""
1616
17- from typing import Tuple , Optional
17+ from typing import Tuple , Optional , Dict , Union , Any
1818import jax
1919import jax .numpy as jnp
2020from flax import nnx
2121from .... import common_types , max_logging
2222from ...modeling_flax_utils import FlaxModelMixin
23- from ....configuration_utils import ConfigMixin
24- from ...embeddings_flax import get_1d_rotary_pos_embed
23+ from ....configuration_utils import ConfigMixin , register_to_config
24+ from ...embeddings_flax import get_1d_rotary_pos_embed , NNXFlaxTimesteps , NNXTimestepEmbedding
2525
2626BlockSizes = common_types .BlockSizes
2727
@@ -65,7 +65,7 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
6565 cumulative_sizes = jnp .cumsum (jnp .array (sizes ))
6666 split_indices = cumulative_sizes [:- 1 ]
6767 freqs_split = jnp .split (self .freqs , split_indices , axis = 1 )
68-
68+
6969 freqs_f = jnp .expand_dims (jnp .expand_dims (freqs_split [0 ][:ppf ], axis = 1 ), axis = 1 )
7070 freqs_f = jnp .broadcast_to (freqs_f , (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ]))
7171
@@ -80,6 +80,40 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
8080 return freqs_final
8181
8282
class WanTimeTextImageEmbedding(nnx.Module):
  """Combined timestep / text / image conditioning embedder for the Wan model.

  Only the timestep pathway is implemented so far; the projection layers
  implied by ``time_proj_dim``, ``text_embed_dim``, ``image_embed_dim`` and
  ``pos_embed_seq_len`` are accepted for config compatibility but not yet
  built.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      dim: int,
      time_freq_dim: int,
      time_proj_dim: int,
      text_embed_dim: int,
      image_embed_dim: Optional[int] = None,
      pos_embed_seq_len: Optional[int] = None,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
  ):
    # Sinusoidal frequency features for the scalar diffusion timestep.
    self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0)
    # MLP lifting the frequency features to the transformer width ``dim``.
    self.time_embedder = NNXTimestepEmbedding(
        rngs=rngs,
        in_channels=time_freq_dim,
        time_embed_dim=dim,
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
    )

  def __call__(
      self,
      timestep: jax.Array,
      encoder_hidden_states: jax.Array,
      encoder_hidden_states_image: Optional[jax.Array] = None,
  ):
    """Embed the timestep and pass conditioning states through.

    Returns a 4-tuple ``(temb, timestep_proj, encoder_hidden_states,
    encoder_hidden_states_image)`` matching the unpacking done by the
    caller. ``timestep_proj`` is ``None`` until the extra time-projection
    layer is implemented; text / image states are returned unchanged.
    """
    timestep = self.timesteps_proj(timestep)
    temb = self.time_embedder(timestep)
    # Fixed: a stray ``breakpoint()`` (debugger trap) was left here and the
    # method implicitly returned None, breaking the caller's 4-way unpacking.
    # TODO: project activated ``temb`` to ``time_proj_dim`` for
    # ``timestep_proj`` and embed text/image conditions once those layers
    # exist (cf. the diffusers reference implementation).
    return temb, None, encoder_hidden_states, encoder_hidden_states_image
116+
83117class WanTransformer3DModel (nnx .Module , FlaxModelMixin , ConfigMixin ):
84118 def __init__ (
85119 self ,
@@ -120,25 +154,28 @@ def __init__(
120154
121155
class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
  """Wan 2.x video diffusion transformer (Flax NNX port) — work in progress.

  Only the patch embedding, rotary position embedding and the conditioning
  embedder are wired up; the transformer blocks, output norm/projection and
  the un-patchify step are still missing, so ``__call__`` currently returns
  the patch-embedded hidden states.
  """

  @register_to_config
  def __init__(
      self,
      rngs: nnx.Rngs,
      model_type="t2v",
      patch_size: Tuple[int] = (1, 2, 2),
      num_attention_heads: int = 40,
      attention_head_dim: int = 128,
      in_channels: int = 16,
      out_channels: int = 16,
      text_dim: int = 4096,
      freq_dim: int = 256,
      ffn_dim: int = 13824,
      num_layers: int = 40,
      cross_attn_norm: bool = True,
      qk_norm: Optional[str] = "rms_norm_across_heads",
      eps: float = 1e-6,
      image_dim: Optional[int] = None,
      added_kn_proj_dim: Optional[int] = None,
      rope_max_seq_len: int = 1024,
      pos_embed_seq_len: Optional[int] = None,
      flash_min_seq_length: int = 4096,
      flash_block_sizes: BlockSizes = None,
      mesh: jax.sharding.Mesh = None,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
      precision: jax.lax.Precision = None,
      attention: str = "dot_product",
  ):
    # Transformer width is heads * head_dim (e.g. 40 * 128 = 5120 for 14B).
    inner_dim = num_attention_heads * attention_head_dim
    out_channels = out_channels or in_channels  # fall back to input channel count

    # 1. Patch & position embedding.
    self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
    self.patch_embedding = nnx.Conv(
        in_channels,
        inner_dim,
        rngs=rngs,
        kernel_size=patch_size,
        strides=patch_size,  # stride == kernel -> non-overlapping patches
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        # NOTE(review): partitioning a conv *kernel* along ("batch",) looks
        # unusual for a weight tensor — confirm the intended sharding axis.
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("batch",)),
    )

    # 2. Condition embeddings (image_embed_dim=1280 for the I2V model).
    self.condition_embedder = WanTimeTextImageEmbedding(
        rngs=rngs,
        dim=inner_dim,
        time_freq_dim=freq_dim,
        time_proj_dim=inner_dim * 6,
        text_embed_dim=text_dim,
        image_embed_dim=image_dim,
        pos_embed_seq_len=pos_embed_seq_len,
    )

  def __call__(
      self,
      hidden_states: jax.Array,
      timestep: jax.Array,
      encoder_hidden_states: jax.Array,
      encoder_hidden_states_image: Optional[jax.Array] = None,
      return_dict: bool = True,
      attention_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Union[jax.Array, Dict[str, jax.Array]]:
    """Run the (partial) forward pass.

    Args:
      hidden_states: latent video, channels-last ``(B, F, H, W, C)``
        (the torch reference is channels-first ``(B, C, F, H, W)``).
      timestep: diffusion timestep(s).
      encoder_hidden_states: text-encoder conditioning states.
      encoder_hidden_states_image: optional image conditioning (I2V).
      return_dict / attention_kwargs: accepted for interface parity;
        not used yet.

    Returns:
      Patch-embedded hidden states ``(B, F'*H'*W', inner_dim)`` until the
      transformer blocks and un-patchify are implemented.
    """
    batch_size, num_frames, height, width, num_channels = hidden_states.shape
    p_t, p_h, p_w = self.config.patch_size
    # Patch-grid extents; kept for the (not yet implemented) un-patchify
    # step at the end of the network.
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p_h
    post_patch_width = width // p_w

    # Rotary embedding is computed from the *pre*-patchified latent shape.
    rotary_emb = self.rope(hidden_states)

    # (B, F, H, W, C) -conv-> (B, F', H', W', inner_dim)
    # -collapse dims 1..-2-> (B, F'*H'*W', inner_dim); channels-last, so the
    # feature axis stays last (torch reference flattens to (B, C, F*H*W)).
    hidden_states = self.patch_embedding(hidden_states)
    hidden_states = jax.lax.collapse(hidden_states, 1, -1)

    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
        timestep, encoder_hidden_states, encoder_hidden_states_image
    )

    # TODO: transformer blocks, output norm/projection and un-patchify are
    # not implemented yet; returning patch-embedded states for now.
    return hidden_states
0 commit comments