   NNXPixArtAlphaTextProjection
 )
 from ...normalization_flax import FP32LayerNorm
+from ...attention_flax import FlaxWanAttention
 
 BlockSizes = common_types.BlockSizes
 
@@ -181,6 +182,89 @@ def __init__(
       rope_max_seq_len
     )
 
+class ApproximateGELU(nnx.Module):
+  r"""
+  The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+  [paper](https://arxiv.org/abs/1606.08415).
+  """
+  def __init__(
+    self,
+    rngs: nnx.Rngs,
+    dim_in: int,
+    dim_out: int,
+    bias: bool,
+    dtype: jnp.dtype = jnp.float32,
+    weights_dtype: jnp.dtype = jnp.float32,
+    precision: jax.lax.Precision = None,
+  ):
+    self.proj = nnx.Linear(
+      rngs=rngs,
+      in_features=dim_in,
+      out_features=dim_out,
+      use_bias=bias,
+      dtype=dtype,
+      param_dtype=weights_dtype,
+      precision=precision,
+      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp")),
+      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
+
+  def __call__(self, x: jax.Array) -> jax.Array:
+    x = self.proj(x)
+    # Sigmoid-based approximation of GELU: x * sigmoid(1.702 * x).
+    return x * jax.nn.sigmoid(1.702 * x)
+
+
+class WanFeedForward(nnx.Module):
+  def __init__(
+    self,
+    rngs: nnx.Rngs,
+    dim: int,
+    dim_out: Optional[int] = None,
+    mult: int = 4,
+    dropout: float = 0.0,
+    activation_fn: str = "geglu",
+    final_dropout: bool = False,
+    inner_dim: Optional[int] = None,
+    bias: bool = True,
+    dtype: jnp.dtype = jnp.float32,
+    weights_dtype: jnp.dtype = jnp.float32,
+    precision: jax.lax.Precision = None,
+  ):
+    if inner_dim is None:
+      inner_dim = int(dim * mult)
+    dim_out = dim_out if dim_out is not None else dim
+
+    self.act_fn = None
+    if activation_fn == "gelu-approximate":
+      self.act_fn = ApproximateGELU(
+        rngs=rngs,
+        dim_in=dim,
+        dim_out=inner_dim,
+        bias=bias,
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        precision=precision,
+      )
+    else:
+      raise NotImplementedError(f"{activation_fn} is not implemented.")
+
+    self.proj_out = nnx.Linear(
+      rngs=rngs,
+      in_features=inner_dim,
+      out_features=dim_out,
+      use_bias=bias,
+      dtype=dtype,
+      param_dtype=weights_dtype,
+      precision=precision,
+      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed")),
+      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
+
+  def __call__(self, hidden_states: jax.Array) -> jax.Array:
+    hidden_states = self.act_fn(hidden_states)
+    return self.proj_out(hidden_states)
+
+
 
 class WanTransformerBlock(nnx.Module):
   def __init__(
@@ -192,17 +276,107 @@ def __init__(
     qk_norm: str = "rms_norm_across_heads",
     cross_attn_norm: bool = False,
     eps: float = 1e-6,
-    added_kv_proj_dim: Optional[int] = None
+    # In the PyTorch reference this is always None, so it is omitted here.
+    # added_kv_proj_dim: Optional[int] = None,
+    flash_min_seq_length: int = 4096,
+    flash_block_sizes: BlockSizes = None,
+    mesh: jax.sharding.Mesh = None,
+    dtype: jnp.dtype = jnp.float32,
+    weights_dtype: jnp.dtype = jnp.float32,
+    precision: jax.lax.Precision = None,
+    attention: str = "dot_product",
   ):
+
+    # 1. Self-attention
     self.norm1 = FP32LayerNorm(
+      rngs=rngs,
       dim=dim,
       eps=eps,
       elementwise_affine=False
     )
+    self.attn1 = FlaxWanAttention(
+      rngs=rngs,
+      query_dim=dim,
+      heads=num_heads,
+      dim_head=dim // num_heads,
+      qk_norm=qk_norm,
+      eps=eps,
+      flash_min_seq_length=flash_min_seq_length,
+      flash_block_sizes=flash_block_sizes,
+      mesh=mesh,
+      dtype=dtype,
+      weights_dtype=weights_dtype,
+      precision=precision,
+      attention_kernel=attention
+    )
+
+    # 2. Cross-attention
+    self.attn2 = FlaxWanAttention(
+      rngs=rngs,
+      query_dim=dim,
+      heads=num_heads,
+      dim_head=dim // num_heads,
+      qk_norm=qk_norm,
+      eps=eps,
+      flash_min_seq_length=flash_min_seq_length,
+      flash_block_sizes=flash_block_sizes,
+      mesh=mesh,
+      dtype=dtype,
+      weights_dtype=weights_dtype,
+      precision=precision,
+      attention_kernel=attention
+    )
+    assert cross_attn_norm, "cross_attn_norm is expected to be True for this block."
+    self.norm2 = FP32LayerNorm(
+      rngs=rngs,
+      dim=dim,
+      eps=eps,
+      elementwise_affine=True
+    )
+
+    # 3. Feed-forward
+    self.ffn = WanFeedForward(
+      rngs=rngs,
+      dim=dim,
+      inner_dim=ffn_dim,
+      activation_fn="gelu-approximate",
+      dtype=dtype,
+      weights_dtype=weights_dtype,
+      precision=precision
+    )
+    self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
+
+    key = rngs.params()
+    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)
 
-  def __call__(self):
-    pass
+  def __call__(
+    self,
+    hidden_states: jax.Array,
+    encoder_hidden_states: jax.Array,
+    temb: jax.Array,
+    rotary_emb: jax.Array
+  ):
+    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+      (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+    )
+
+    # 1. Self-attention
+    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
+    attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+    # 2. Cross-attention
+    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32))
+    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+    hidden_states = hidden_states + attn_output
+
+    # 3. Feed-forward
+    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
 
+    ff_output = self.ffn(norm_hidden_states)
+    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype)
+    return hidden_states
 
 
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -269,7 +443,22 @@ def __init__(
     # 3. Transformer blocks
     blocks = []
     for _ in range(num_layers):
-      block = WanTransformerBlock()
+      block = WanTransformerBlock(
+        rngs=rngs,
+        dim=inner_dim,
+        ffn_dim=ffn_dim,
+        num_attention_heads=num_attention_heads,
+        qk_norm=qk_norm,
+        cross_attn_norm=cross_attn_norm,
+        eps=eps,
+        flash_min_seq_length=flash_min_seq_length,
+        flash_block_sizes=flash_block_sizes,
+        mesh=mesh,
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        precision=precision,
+        attention=attention
+      )
       blocks.append(block)
     self.blocks = blocks
 
@@ -301,8 +490,9 @@ def __call__(
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
 
-    # for block in self.blocks:
-
+    for block in self.blocks:
+      hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
 
     return hidden_states
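
For reference, the adaLN-style modulation that WanTransformerBlock.__call__ applies can be reproduced in isolation: the per-block scale_shift_table is added to the projected time embedding and split into six terms (shift/scale/gate for the self-attention path and for the feed-forward path). The sketch below is illustrative only; the toy shapes, random inputs, standalone layer_norm helper, and the identity stand-in for attention are assumptions, not part of this change.

import jax
import jax.numpy as jnp

dim, seq_len = 64, 16                                    # assumed toy sizes
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

hidden_states = jax.random.normal(k1, (1, seq_len, dim))
temb = jax.random.normal(k2, (1, 6, dim))                # stand-in for timestep_proj
scale_shift_table = jax.random.normal(k3, (1, 6, dim)) / dim**0.5

# Six modulation terms, each of shape (1, 1, dim), broadcast over the sequence axis.
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
    scale_shift_table + temb.astype(jnp.float32), 6, axis=1
)

def layer_norm(x, eps=1e-6):
  # Affine-free LayerNorm in float32, mirroring FP32LayerNorm(elementwise_affine=False).
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.var(x, axis=-1, keepdims=True)
  return (x - mean) * jax.lax.rsqrt(var + eps)

# Self-attention path: normalize, apply scale/shift, run attention (identity stand-in here),
# then gate the output before the residual add.
norm_hidden = layer_norm(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa
attn_output = norm_hidden                                # stand-in for self.attn1(...)
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

# The feed-forward path uses the c_* terms the same way with norm3 and the feed-forward network.
print(hidden_states.shape)                               # (1, 16, 64)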