 from jax.sharding import PartitionSpec
 from jax.lax import with_sharding_constraint
 
+
 def _update_cache(cache, idx, value):
     if cache is None:
         return None
-    return cache[:idx] + (value,) + cache[idx + 1:]
+    return cache[:idx] + (value,) + cache[idx + 1:]
+
 
 # Helper to ensure kernel_size, stride, padding are tuples of 3 integers
 def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
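An aside on the `_update_cache` helper above: it never mutates the cache tuple, it splices a fresh one together, which is the usual pattern for container state in JAX-friendly code. A minimal standalone sketch of how it behaves (the cache values here are placeholders):

def _update_cache(cache, idx, value):
    if cache is None:
        return None
    return cache[:idx] + (value,) + cache[idx + 1:]

cache = (None, None, None)
cache = _update_cache(cache, 1, "feat")
assert cache == (None, "feat", None)
assert _update_cache(None, 0, "x") is None  # a disabled cache stays disabled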
@@ -55,11 +57,14 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
 
 
 class RepSentinel:
+
     def __eq__(self, other):
         return isinstance(other, RepSentinel)
 
+
 tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())
 
+
 class WanCausalConv3d(nnx.Module):
 
     def __init__(
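A note on the `RepSentinel` registration in the hunk above: flattening to an empty tuple of children means the sentinel contributes no leaves, so it travels through pytree containers as pure structure rather than as a traced array, and `__eq__` makes any two instances interchangeable when tree structures are compared. A standalone sketch of the same pattern, assuming nothing beyond what the diff shows:

import jax.numpy as jnp
from jax import tree_util

class RepSentinel:
    def __eq__(self, other):
        return isinstance(other, RepSentinel)

tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())

# The sentinel flattens to zero leaves; only the array survives as a leaf.
leaves, treedef = tree_util.tree_flatten({"flag": RepSentinel(), "x": jnp.ones(2)})
assert len(leaves) == 1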
@@ -503,7 +508,6 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
-
         identity = x
         batch_size, time, height, width, channels = x.shape
 
@@ -949,9 +953,11 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
 class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
     pass
 
+
 def _wan_diag_gauss_dist_flatten(dist):
     return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)
 
+
 def _wan_diag_gauss_dist_unflatten(aux, children):
     mean, logvar, std, var = children
     deterministic = aux[0]
@@ -963,6 +969,7 @@ def _wan_diag_gauss_dist_unflatten(aux, children):
     obj.deterministic = deterministic
     return obj
 
+
 tree_util.register_pytree_node(
     WanDiagonalGaussianDistribution,
     _wan_diag_gauss_dist_flatten,
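The flatten/unflatten pair above follows the usual recipe for making a plain Python object transparent to `jax.jit`: array attributes become pytree children, the `deterministic` flag becomes static aux data, and unflatten rebuilds the instance without re-running `__init__`. A minimal sketch of the recipe with a hypothetical two-field `Dist` class (not the file's actual distribution):

import jax
import jax.numpy as jnp
from jax import tree_util

class Dist:
    def __init__(self, mean, logvar, deterministic=False):
        self.mean, self.logvar, self.deterministic = mean, logvar, deterministic

def _flatten(d):
    return (d.mean, d.logvar), (d.deterministic,)  # arrays traced, bool static

def _unflatten(aux, children):
    obj = object.__new__(Dist)  # skip __init__, mirroring the code above
    obj.mean, obj.logvar = children
    obj.deterministic = aux[0]
    return obj

tree_util.register_pytree_node(Dist, _flatten, _unflatten)

# Registered objects can cross a jit boundary in either direction.
d = jax.jit(lambda m: Dist(m, jnp.zeros_like(m)))(jnp.ones(3))
assert isinstance(d, Dist) and d.deterministic is False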
@@ -993,9 +1000,11 @@ def init_cache(self):
         # cache encode
         self._enc_feat_map = (None,) * self._enc_conv_num
 
+
 def _wan_cache_flatten(cache):
     return (cache._feat_map, cache._enc_feat_map), (cache._conv_num, cache._enc_conv_num)
 
+
 def _wan_cache_unflatten(aux, children):
     conv_num, enc_conv_num = aux
     feat_map, enc_feat_map = children
@@ -1009,9 +1018,10 @@ def _wan_cache_unflatten(aux, children):
     obj._enc_conv_num = enc_conv_num
     obj._feat_map = feat_map
     obj._enc_feat_map = enc_feat_map
-    obj.module = None # module is not needed inside the trace for the cache logic now
+    obj.module = None  # module is not needed inside the trace for the cache logic now
     return obj
 
+
 tree_util.register_pytree_node(AutoencoderKLWanCache, _wan_cache_flatten, _wan_cache_unflatten)
 
 
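Registering `AutoencoderKLWanCache` the same way is what lets the per-frame feature cache be passed into and out of jitted code, with the conv counts as static structure and the feature maps as traced leaves; setting `module = None` in unflatten keeps the full nnx module from being dragged into the trace. A hedged sketch of the resulting calling convention, using a toy stand-in cache rather than the real class:

import jax
import jax.numpy as jnp
from jax import tree_util

class Cache:
    pass

def _flatten(c):
    return (c.feat_map,), (c.conv_num,)

def _unflatten(aux, children):
    obj = object.__new__(Cache)
    obj.conv_num = aux[0]
    obj.feat_map = children[0]
    obj.module = None  # heavyweight module never enters the trace
    return obj

tree_util.register_pytree_node(Cache, _flatten, _unflatten)

@jax.jit
def advance(cache, x):
    # Rebuild a new cache functionally; nothing is mutated in place.
    leaves, treedef = tree_util.tree_flatten(cache)
    return tree_util.tree_unflatten(treedef, [f + x for f in leaves])

c = object.__new__(Cache)
c.conv_num, c.feat_map = 2, (jnp.zeros(3), jnp.ones(3))
c2 = advance(c, 1.0)
assert c2.module is None and c2.conv_num == 2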
@@ -1147,10 +1157,10 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
                 feat_idx=enc_conv_idx,
             )
             out = jnp.concatenate([out, out_], axis=1)
-
+
         # Update back to the wrapper object if needed, but for result we use local vars
         feat_cache._enc_feat_map = enc_feat_map
-
+
         enc = self.quant_conv(out)
         mu, logvar = enc[:, :, :, :, :self.z_dim], enc[:, :, :, :, self.z_dim:]
         enc = jnp.concatenate([mu, logvar], axis=-1)
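The last lines above split the `quant_conv` output along the channel axis into the first `z_dim` mean channels and the remaining log-variance channels, which is the standard diagonal-Gaussian VAE parameterization; sampling from it uses the reparameterization trick. A sketch of the assumed sampling semantics (standalone, with a made-up shape; the clip bounds are the common diffusers-style stabilizer, not taken from this diff):

import jax
import jax.numpy as jnp

def sample_diag_gaussian(key, enc, z_dim):
    # enc carries 2 * z_dim channels in its last axis: [mean | logvar]
    mu, logvar = enc[..., :z_dim], enc[..., z_dim:]
    logvar = jnp.clip(logvar, -30.0, 20.0)
    std = jnp.exp(0.5 * logvar)
    return mu + std * jax.random.normal(key, mu.shape)  # reparameterization

enc = jnp.zeros((1, 4, 8, 8, 2 * 16))  # (batch, time, height, width, 2 * z_dim)
z = sample_diag_gaussian(jax.random.PRNGKey(0), enc, z_dim=16)
assert z.shape == (1, 4, 8, 8, 16)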
@@ -1173,7 +1183,7 @@ def _decode(
         feat_cache.init_cache()
         iter_ = z.shape[1]
         x = self.post_quant_conv(z)
-
+
         dec_feat_map = feat_cache._feat_map
 
         for i in range(iter_):
@@ -1199,7 +1209,7 @@
             fm3 = jnp.expand_dims(fm3, axis=axis)
             fm4 = jnp.expand_dims(fm4, axis=axis)
             out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
-
+
         feat_cache._feat_map = dec_feat_map
 
         out = jnp.clip(out, min=-1.0, max=1.0)
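The decoder's final `jnp.clip` pins the output frames to the [-1, 1] pixel convention used by diffusion VAEs, so turning them into displayable images is a plain rescale. A small post-processing sketch (assumed layout, not part of the diff):

import numpy as np
import jax.numpy as jnp

def to_uint8(frames):
    # frames in [-1, 1] -> uint8 in [0, 255]
    frames = jnp.clip(frames, min=-1.0, max=1.0)
    return np.asarray((frames + 1.0) * 127.5).astype(np.uint8)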