@@ -48,19 +48,14 @@ def _update_cache(cache, idx, value):
 
 
 # Helper to ensure kernel_size, stride, padding are tuples of 3 integers
-def _canonicalize_tuple(
-    x: Union[int, Sequence[int]], rank: int, name: str
-) -> Tuple[int, ...]:
+def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
   """Canonicalizes a value to a tuple of integers."""
   if isinstance(x, int):
     return (x,) * rank
   elif isinstance(x, Sequence) and len(x) == rank:
     return tuple(x)
   else:
-    raise ValueError(
-        f"Argument '{name}' must be an integer or a sequence of {rank}"
-        f" integers. Got {x}"
-    )
+    raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank}" f" integers. Got {x}")
 
 
 class RepSentinel:
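
Note: the one-lined signature above does not change the helper's contract. A self-contained sketch of that contract (the function mirrors _canonicalize_tuple; demo values are illustrative):

from typing import Sequence, Tuple, Union

def canonicalize(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
  if isinstance(x, int):
    return (x,) * rank  # an int broadcasts to all `rank` dims
  elif isinstance(x, Sequence) and len(x) == rank:
    return tuple(x)  # a matching-length sequence passes through
  else:
    raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}")

assert canonicalize(3, 3, "kernel_size") == (3, 3, 3)
assert canonicalize((1, 2, 2), 3, "stride") == (1, 2, 2)
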
@@ -69,9 +64,7 @@ def __eq__(self, other):
     return isinstance(other, RepSentinel)
 
 
-tree_util.register_pytree_node(
-    RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel()
-)
+tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())
 
 
 class WanPatchify(nnx.Module):
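
Note: registering the sentinel as a leafless pytree node is what lets it sit inside the feat_cache pytree that flows through jitted calls without contributing any traceable arrays. A minimal sketch of that behavior, using a hypothetical Sentinel stand-in:

import jax
from jax import tree_util

class Sentinel:  # stand-in for RepSentinel
  def __eq__(self, other):
    return isinstance(other, Sentinel)

# Flatten to no children; rebuild a fresh instance on unflatten.
tree_util.register_pytree_node(Sentinel, lambda x: ((), None), lambda _, __: Sentinel())

cache = [Sentinel(), jax.numpy.zeros((1, 2))]
leaves = tree_util.tree_leaves(cache)
assert len(leaves) == 1  # only the array is a leaf; the sentinel flattens to nothing
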
@@ -217,9 +210,7 @@ def __init__(
     self.bias = 0
 
   def __call__(self, x: jax.Array) -> jax.Array:
-    normalized = jnp.linalg.norm(
-        x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True
-    )
+    normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
     normalized = x / jnp.maximum(normalized, self.eps)
     normalized = normalized * self.scale * self.gamma
     if self.bias:
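
Note: the reformatted call is an L2 normalization over the channel axis followed by a learned gain. A numerically equivalent sketch, assuming (as in the WanRMS_norm modules this port follows) scale = dim ** 0.5 and a per-channel gamma:

import jax.numpy as jnp

def rms_norm_channels_last(x, gamma, eps=1e-12):
  scale = x.shape[-1] ** 0.5
  norm = jnp.linalg.norm(x, ord=2, axis=-1, keepdims=True)  # L2 norm over channels
  return x / jnp.maximum(norm, eps) * scale * gamma

x = jnp.ones((2, 4, 8))  # (batch, tokens, channels)
gamma = jnp.ones((8,))
assert rms_norm_channels_last(x, gamma).shape == x.shape
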
@@ -229,9 +220,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
 
 class WanUpsample(nnx.Module):
 
-  def __init__(
-      self, scale_factor: Tuple[float, float], method: str = "nearest"
-  ):
+  def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
     # scale_factor for (H, W)
     # JAX resize works on spatial dims, H, W assuming (N, D, H, W, C) or (N, H, W, C)
     self.scale_factor = scale_factor
@@ -244,9 +233,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
     n, h, w, c = in_shape
     target_h = int(h * self.scale_factor[0])
     target_w = int(w * self.scale_factor[1])
-    out = jax.image.resize(
-        x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method
-    )
+    out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
     return out.astype(input_dtype)
 
 
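
Note: the resize path above upsamples H and W in float32 and casts back to the input dtype. A standalone sketch with illustrative shapes:

import jax
import jax.numpy as jnp

x = jnp.zeros((1, 16, 16, 32), dtype=jnp.bfloat16)  # (N, H, W, C)
scale_factor = (2.0, 2.0)
n, h, w, c = x.shape
target = (n, int(h * scale_factor[0]), int(w * scale_factor[1]), c)
out = jax.image.resize(x.astype(jnp.float32), target, method="nearest").astype(x.dtype)
assert out.shape == (1, 32, 32, 32)
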
@@ -282,9 +269,7 @@ def __init__(
         use_bias=True,
         padding=[(0, 1), (0, 1)],
         rngs=rngs,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(), (None, None, None, None)
-        ),
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None)),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
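
Note: nnx.with_partitioning wraps an initializer so the created kernel carries sharding metadata; an all-None tuple means the (kh, kw, in, out) kernel is replicated rather than sharded. A minimal sketch of the same pattern outside this module (shapes and features are illustrative):

import jax
from flax import nnx

conv = nnx.Conv(
    in_features=16,
    out_features=32,
    kernel_size=(3, 3),
    kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None)),
    rngs=nnx.Rngs(0),
)
y = conv(jax.numpy.zeros((1, 8, 8, 16)))  # (N, H, W, C) -> (N, H, W, 32)
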
@@ -409,11 +394,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
           feat_idx += 1
         else:
           cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-          if (
-              cache_x.shape[1] < 2
-              and feat_cache[idx] is not None
-              and not isinstance(feat_cache[idx], RepSentinel)
-          ):
+          if cache_x.shape[1] < 2 and feat_cache[idx] is not None and not isinstance(feat_cache[idx], RepSentinel):
             # cache last frame of last two chunk
             cache_x = jnp.concatenate(
                 [
@@ -422,14 +403,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
                 ],
                 axis=1,
             )
-            if (
-                cache_x.shape[1] < 2
-                and feat_cache[idx] is not None
-                and isinstance(feat_cache[idx], RepSentinel)
-            ):
-              cache_x = jnp.concatenate(
-                  [jnp.zeros(cache_x.shape), cache_x], axis=1
-              )
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], RepSentinel):
+              cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1)
           if isinstance(feat_cache[idx], RepSentinel):
             x = self.time_conv(x)
           else:
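
Note: the two conditionals collapsed above implement a rolling temporal cache: keep the last CACHE_T frames of each chunk, and when the current chunk supplies fewer than two frames, top the window up from the previous cache entry (or with zeros when that entry is the RepSentinel marker). An illustrative sketch of the padding rule:

import jax.numpy as jnp

CACHE_T = 2
x = jnp.ones((1, 1, 4, 4, 8))      # chunk with a single frame, (N, T, H, W, C)
prev = jnp.zeros((1, 2, 4, 4, 8))  # previous cache entry holding CACHE_T frames

cache_x = x[:, -CACHE_T:]          # only 1 frame available from this chunk
if cache_x.shape[1] < 2:
  # normal (non-sentinel) case: borrow the last frame of the previous entry
  cache_x = jnp.concatenate([prev[:, -1:], cache_x], axis=1)
assert cache_x.shape[1] == CACHE_T
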
@@ -453,9 +428,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
           feat_idx += 1
         else:
           cache_x = jnp.copy(x[:, -1:, :, :, :])
-          x = self.time_conv(
-              jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)
-          )
+          x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
           feat_cache = _update_cache(feat_cache, idx, cache_x)
           feat_idx += 1
 
@@ -479,9 +452,7 @@ def __init__(
     self.nonlinearity = get_activation(non_linearity)
 
     # layers
-    self.norm1 = WanRMS_norm(
-        dim=in_dim, rngs=rngs, images=False, channel_first=False
-    )
+    self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False)
     self.conv1 = WanCausalConv3d(
         rngs=rngs,
         in_channels=in_dim,
@@ -493,9 +464,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
     )
-    self.norm2 = WanRMS_norm(
-        dim=out_dim, rngs=rngs, images=False, channel_first=False
-    )
+    self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
     self.conv2 = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -581,9 +550,7 @@ def __init__(
         out_features=dim * 3,
         kernel_size=(1, 1),
         rngs=rngs,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")
-        ),
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
@@ -593,9 +560,7 @@ def __init__(
         out_features=dim,
         kernel_size=(1, 1),
         rngs=rngs,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(), (None, None, "conv_in", None)
-        ),
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "conv_in", None)),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
@@ -709,9 +674,7 @@ def __init__(
     self.factor = self.factor_t * self.factor_s * self.factor_s
     self.group_size = in_channels * self.factor // out_channels
 
-  def __call__(
-      self, x: jax.Array, feat_cache=None, feat_idx=0
-  ) -> Tuple[jax.Array, Any, int]:
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0) -> Tuple[jax.Array, Any, int]:
     if self.factor_t > 1 or self.factor_s > 1:
       n, d, h, w, c = x.shape
       pad_d = (self.factor_t - d % self.factor_t) % self.factor_t
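
Note: the factor and group_size lines in this hunk do the channel bookkeeping for a space-to-depth downsample. A worked sketch of that arithmetic (the numbers are illustrative, not the model's configuration):

factor_t, factor_s = 2, 2
in_channels, out_channels = 96, 96
factor = factor_t * factor_s * factor_s            # 8 input voxels fold into channels
group_size = in_channels * factor // out_channels  # 96 * 8 // 96 = 8 channels per group
assert in_channels * factor == out_channels * group_size
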
@@ -769,9 +732,7 @@ def __init__(
     self.out_channels = out_channels
     self.repeats = out_channels * self.factor // in_channels
 
-  def __call__(
-      self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False
-  ) -> Tuple[jax.Array, Any, int]:
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False) -> Tuple[jax.Array, Any, int]:
     # Duplicate channels to match the expected total channels for upsampling.
     # x: (N, D, H, W, in_channels) -> (N, D, H, W, in_channels * self.repeats)
     x = jnp.repeat(x, repeats=self.repeats, axis=4)
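
Note: the jnp.repeat above is the upsampler's mirror of the space-to-depth bookkeeping: each channel is tiled `repeats` times so a later depth-to-space step can consume factor-many channels per output channel. A small sketch with illustrative shapes:

import jax.numpy as jnp

repeats = 4
x = jnp.arange(6.0).reshape((1, 1, 1, 1, 6))  # (N, D, H, W, C)
y = jnp.repeat(x, repeats=repeats, axis=4)
assert y.shape == (1, 1, 1, 1, 24)
assert (y[..., :repeats] == x[..., :1]).all()  # each channel appears `repeats` times consecutively
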
@@ -891,9 +852,7 @@ def __call__(
 
     x_shortcut = None
     if self.avg_shortcut is not None:
-      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(
-          x_main, feat_cache, feat_idx
-      )
+      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(x_main, feat_cache, feat_idx)
       x = x + x_shortcut
 
     if return_shortcut:
@@ -994,9 +953,7 @@ def __call__(
 
     x_shortcut = None
     if self.avg_shortcut is not None:
-      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(
-          x_main, feat_cache, feat_idx, first_chunk
-      )
+      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(x_main, feat_cache, feat_idx, first_chunk)
       x = x + x_shortcut
 
     if return_shortcut:
@@ -1052,9 +1009,7 @@ def __init__(
     self.down_blocks = []
     for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
       if i != len(dim_mult) - 1:
-        downsample_mode = (
-            "downsample3d" if temperal_downsample[i] else "downsample2d"
-        )
+        downsample_mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
       else:
         downsample_mode = None
       self.down_blocks.append(
@@ -1120,9 +1075,7 @@ def __init__(
     )
 
     # output blocks
-    self.norm_out = WanRMS_norm(
-        out_dim, channel_first=False, images=False, rngs=rngs
-    )
+    self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs)
     self.conv_out = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -1281,9 +1234,7 @@ def __init__(
     self.up_blocks = nnx.data(self.up_blocks)
 
     # output blocks
-    self.norm_out = WanRMS_norm(
-        dim=out_dim, images=False, rngs=rngs, channel_first=False
-    )
+    self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
     self.conv_out = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -1297,9 +1248,7 @@ def __init__(
     )
 
   @nnx.jit(static_argnames=("feat_idx", "first_chunk"))
-  def __call__(
-      self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False
-  ):
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False):
     if feat_cache is not None:
       idx = feat_idx
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
@@ -1553,9 +1502,7 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
     if x.shape[-1] != 3:
       # reshape channel last for JAX
       x = jnp.transpose(x, (0, 2, 3, 4, 1))
-    assert (
-        x.shape[-1] == 3
-    ), f"Expected input shape (N, D, H, W, 3), got {x.shape}"
+    assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
 
     x = self.patchify(x)
 
@@ -1566,9 +1513,7 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
     for i in range(iter_):
       enc_conv_idx = 0
       if i == 0:
-        out, enc_feat_map, enc_conv_idx = self.encoder(
-            x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx
-        )
+        out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
       else:
         out_, enc_feat_map, enc_conv_idx = self.encoder(
             x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
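
Note: the encoder consumes the first frame alone and the remaining frames in windows of 4 via the slice 1 + 4*(i-1) : 1 + 4*i. A worked example of that schedule, assuming the loop bound iter_ = 1 + (T - 1) // 4 set up before this hunk:

T = 17  # e.g. a 17-frame clip: 1 + 4 * 4
iter_ = 1 + (T - 1) // 4
windows = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
assert windows == [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
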
@@ -1621,9 +1566,7 @@ def _decode(
             first_chunk=True,
         )
       else:
-        out_, dec_feat_map, conv_idx = self.decoder(
-            x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx
-        )
+        out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
       out = jnp.concatenate([out, out_], axis=1)
 
     feat_cache._feat_map = dec_feat_map
@@ -1645,9 +1588,7 @@ def decode(
     if z.shape[-1] != self.z_dim:
       # reshape channel last for JAX
       z = jnp.transpose(z, (0, 2, 3, 4, 1))
-    assert (
-        z.shape[-1] == self.z_dim
-    ), f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
+    assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
     decoded = self._decode(z, feat_cache).sample
     if not return_dict:
       return (decoded,)
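
Note: this entry point accepts latents in either layout and normalizes to channel-last before the asserted (N, D, H, W, z_dim) shape. An illustrative sketch of that check (z_dim and shapes here are made up):

import jax.numpy as jnp

z_dim = 16
z = jnp.zeros((1, z_dim, 4, 32, 32))  # channel-first input (N, C, D, H, W)
if z.shape[-1] != z_dim:
  z = jnp.transpose(z, (0, 2, 3, 4, 1))  # -> (N, D, H, W, C)
assert z.shape[-1] == z_dim, f"Expected input shape (N, D, H, W, {z_dim}), got {z.shape}"
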