 }


-class FlaxFusedLeakyReLU(nnx.Module):
+class FusedLeakyReLU(nnx.Module):
     """
     Fused LeakyReLU with scale factor and channel-wise bias.
     """
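For context on the op being ported: a minimal sketch of what a StyleGAN2-style fused leaky ReLU computes, assuming the conventional 0.2 slope and sqrt(2) variance-preserving rescale (illustrative helper, not this module's exact code):

import jax
import jax.numpy as jnp

def fused_leaky_relu(x: jax.Array, bias: jax.Array, channel_dim: int = 1) -> jax.Array:
    # Broadcast the per-channel bias along channel_dim, then apply the
    # leaky ReLU and rescale so activation variance stays roughly constant.
    shape = [1] * x.ndim
    shape[channel_dim] = -1
    x = x + bias.reshape(shape)
    return jax.nn.leaky_relu(x, negative_slope=0.2) * jnp.sqrt(2.0)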
@@ -84,7 +84,7 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
         return x


-class FlaxMotionConv2d(nnx.Module):
+class MotionConv2d(nnx.Module):
     """2-D convolution with EqualizedLR scaling and optional FusedLeakyReLU.

     Weights are stored in PyTorch OIHW format (out, in, k, k) as raw nnx.Param
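A minimal sketch of the EqualizedLR scaling the docstring refers to, assuming the usual He-constant-at-runtime formulation (hypothetical helper, not the module's actual API). Keeping the weight in OIHW works because lax.conv_general_dilated accepts PyTorch layouts directly via dimension_numbers:

import math
import jax
from jax import lax

def equalized_conv2d(x: jax.Array, weight: jax.Array) -> jax.Array:
    out_ch, in_ch, k, _ = weight.shape
    # EqualizedLR: weights are initialized ~N(0, 1) and the He constant is
    # multiplied in on every forward pass instead of baked in at init.
    scale = 1.0 / math.sqrt(in_ch * k * k)
    return lax.conv_general_dilated(
        x,
        weight * scale,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # PyTorch layouts as-is
    )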
@@ -148,7 +148,7 @@ def __init__(
         self.bias = None

         if self.use_activation:
-            self.act_fn = FlaxFusedLeakyReLU(
+            self.act_fn = FusedLeakyReLU(
                 rngs=rngs, bias_channels=out_channels, dtype=dtype, weights_dtype=weights_dtype
             )
         else:
@@ -205,11 +205,11 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
         return x


-class FlaxMotionLinear(nnx.Module):
+class MotionLinear(nnx.Module):
     """Equalized-LR linear layer with optional FusedLeakyReLU.

     Weights are stored in PyTorch (out, in) format as raw nnx.Param — same
-    reason as FlaxMotionConv2d. No sharding annotations needed (small layer).
+    reason as MotionConv2d. No sharding annotations needed (small layer).
     """

     def __init__(
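The linear variant follows the same pattern; a minimal sketch with a PyTorch-format (out, in) weight (hypothetical helper, signature is illustrative):

import math
import jax

def equalized_linear(x: jax.Array, weight: jax.Array, bias: jax.Array | None = None) -> jax.Array:
    out_dim, in_dim = weight.shape
    scale = 1.0 / math.sqrt(in_dim)   # He constant applied at call time
    out = x @ (weight * scale).T      # (out, in) storage, hence the transpose
    if bias is not None:
        out = out + bias
    return out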
@@ -238,7 +238,7 @@ def __init__(
         self.bias = None

         if self.use_activation:
-            self.act_fn = FlaxFusedLeakyReLU(rngs=rngs, bias_channels=out_dim, dtype=dtype, weights_dtype=weights_dtype)
+            self.act_fn = FusedLeakyReLU(rngs=rngs, bias_channels=out_dim, dtype=dtype, weights_dtype=weights_dtype)
         else:
             self.act_fn = None

@@ -258,7 +258,7 @@ def __call__(self, inputs: jax.Array, channel_dim: int = 1) -> jax.Array:
         return out


-class FlaxMotionEncoderResBlock(nnx.Module):
+class MotionEncoderResBlock(nnx.Module):

     def __init__(
         self,
@@ -276,7 +276,7 @@ def __init__(
         self.dtype = dtype

         # 3x3 conv + fused leaky ReLU
-        self.conv1 = FlaxMotionConv2d(
+        self.conv1 = MotionConv2d(
             rngs,
             in_channels,
             in_channels,
@@ -289,7 +289,7 @@ def __init__(
         )

         # 3x3 conv + 2x downsample + fused leaky ReLU
-        self.conv2 = FlaxMotionConv2d(
+        self.conv2 = MotionConv2d(
             rngs,
             in_channels,
             out_channels,
@@ -303,7 +303,7 @@ def __init__(
         )

         # 1x1 conv + 2x downsample in skip connection
-        self.conv_skip = FlaxMotionConv2d(
+        self.conv_skip = MotionConv2d(
             rngs,
             in_channels,
             out_channels,
@@ -327,7 +327,7 @@ def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array:
         return x_out


-class FlaxWanAnimateMotionEncoder(nnx.Module):
+class WanAnimateMotionEncoder(nnx.Module):
     """Encodes a face video frame into a motion vector.

     All weights in this network are small (the largest is 32×512→16) so
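To make the construction below easier to follow, a worked sketch of the encoder's data flow, assuming log_size = log2(size) and using illustrative values (size=512, style_dim=512, motion_dim=20 are assumptions, not necessarily the defaults):

import math

size = 512
log_size = int(math.log2(size))                               # 9
resolutions = [2 ** (i - 1) for i in range(log_size, 2, -1)]
print(resolutions)                                            # [256, 128, 64, 32, 16, 8, 4]
# face image (B, 3, 512, 512)
#   -> conv_in (1x1)               : (B, channels["512"], 512, 512)
#   -> 7 res blocks, each 2x down  : (B, channels["4"], 4, 4)
#   -> conv_out                    : (B, style_dim)
#   -> motion_network linears      : (B, motion_dim)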
@@ -353,7 +353,7 @@ def __init__(
         if channels is None:
             channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES

-        self.conv_in = FlaxMotionConv2d(
+        self.conv_in = MotionConv2d(
             rngs, 3, channels[str(size)], 1, use_activation=True, dtype=dtype, weights_dtype=weights_dtype
         )

@@ -363,12 +363,12 @@ def __init__(
         for i in range(log_size, 2, -1):
             out_channels = channels[str(2 ** (i - 1))]
             res_blocks.append(
-                FlaxMotionEncoderResBlock(rngs, in_channels, out_channels, dtype=dtype, weights_dtype=weights_dtype)
+                MotionEncoderResBlock(rngs, in_channels, out_channels, dtype=dtype, weights_dtype=weights_dtype)
             )
             in_channels = out_channels
         self.res_blocks = nnx.List(res_blocks)

-        self.conv_out = FlaxMotionConv2d(
+        self.conv_out = MotionConv2d(
             rngs,
             in_channels,
             style_dim,
@@ -382,9 +382,9 @@ def __init__(

         linears = []
         for _ in range(motion_blocks - 1):
-            linears.append(FlaxMotionLinear(rngs, style_dim, style_dim, dtype=dtype, weights_dtype=weights_dtype))
+            linears.append(MotionLinear(rngs, style_dim, style_dim, dtype=dtype, weights_dtype=weights_dtype))

-        linears.append(FlaxMotionLinear(rngs, style_dim, motion_dim, dtype=dtype, weights_dtype=weights_dtype))
+        linears.append(MotionLinear(rngs, style_dim, motion_dim, dtype=dtype, weights_dtype=weights_dtype))
         self.motion_network = nnx.List(linears)

         key = rngs.params()
@@ -417,7 +417,7 @@ def __call__(self, face_image: jax.Array, channel_dim: int = 1) -> jax.Array:
         return motion_vec.astype(original_dtype)


-class FlaxWanAnimateFaceEncoder(nnx.Module):
+class WanAnimateFaceEncoder(nnx.Module):

     def __init__(
         self,
@@ -544,7 +544,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
         return x


-class FlaxWanAnimateFaceBlockCrossAttention(nnx.Module):
+class WanAnimateFaceBlockCrossAttention(nnx.Module):

     def __init__(
         self,
@@ -763,7 +763,7 @@ def __init__(
             weights_dtype=weights_dtype,
         )

-        self.motion_encoder = FlaxWanAnimateMotionEncoder(
+        self.motion_encoder = WanAnimateMotionEncoder(
             rngs=rngs,
             size=motion_encoder_size,
             style_dim=motion_style_dim,
@@ -773,7 +773,7 @@ def __init__(
             dtype=dtype,
             weights_dtype=weights_dtype,
         )
-        self.face_encoder = FlaxWanAnimateFaceEncoder(
+        self.face_encoder = WanAnimateFaceEncoder(
             rngs=rngs,
             in_dim=motion_encoder_dim,
             out_dim=inner_dim,
@@ -840,7 +840,7 @@ def init_block(rngs):
         face_adapters = []
         num_face_adapters = math.ceil(num_layers / inject_face_latents_blocks)
         for _ in range(num_face_adapters):
-            fa = FlaxWanAnimateFaceBlockCrossAttention(
+            fa = WanAnimateFaceBlockCrossAttention(
                 rngs=rngs,
                 dim=inner_dim,
                 heads=num_attention_heads,
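A quick check on the adapter-count arithmetic above, with hypothetical numbers (num_layers=40 and injection every 5th block are illustrative, not the config's actual values):

import math

num_layers = 40
inject_face_latents_blocks = 5
print(math.ceil(num_layers / inject_face_latents_blocks))  # 8 face adapters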
@@ -1081,3 +1081,4 @@ def layer_forward(hidden_states):
         if not return_dict:
             return (hidden_states,)
         return {"sample": hidden_states}
+