@@ -756,7 +756,7 @@ def __init__(
756756 precision = precision ,
757757 )
758758
759- @nnx .jit
759+ @nnx .jit ( static_argnames = "feat_idx" )
760760 def __call__ (self , x : jax .Array , feat_cache = None , feat_idx = 0 ):
761761 if feat_cache is not None :
762762 idx = feat_idx
@@ -787,7 +787,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
787787 feat_idx += 1
788788 else :
789789 x = self .conv_out (x )
790- return x , feat_cache , feat_idx
790+ return x , feat_cache , jnp . array ( feat_idx , dtype = jnp . int32 )
791791
792792
793793class WanDecoder3d (nnx .Module ):
@@ -905,7 +905,7 @@ def __init__(
905905 precision = precision ,
906906 )
907907
908- @nnx .jit
908+ @nnx .jit ( static_argnames = "feat_idx" )
909909 def __call__ (self , x : jax .Array , feat_cache = None , feat_idx = 0 ):
910910 if feat_cache is not None :
911911 idx = feat_idx
@@ -939,7 +939,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
939939 feat_idx += 1
940940 else :
941941 x = self .conv_out (x )
942- return x , feat_cache , feat_idx
942+ return x , feat_cache , jnp . array ( feat_idx , dtype = jnp . int32 )
943943
944944
945945class WanDiagonalGaussianDistribution (FlaxDiagonalGaussianDistribution ):
0 commit comments