Skip to content

Commit 2b6a758

Browse files
committed
Refactor
1 parent dd27d04 commit 2b6a758

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ def __init__(
756756
precision=precision,
757757
)
758758

759-
@nnx.jit
759+
@nnx.jit(static_argnames="feat_idx")
760760
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
761761
if feat_cache is not None:
762762
idx = feat_idx
@@ -787,7 +787,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
787787
feat_idx += 1
788788
else:
789789
x = self.conv_out(x)
790-
return x, feat_cache, feat_idx
790+
return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32)
791791

792792

793793
class WanDecoder3d(nnx.Module):
@@ -905,7 +905,7 @@ def __init__(
905905
precision=precision,
906906
)
907907

908-
@nnx.jit
908+
@nnx.jit(static_argnames="feat_idx")
909909
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
910910
if feat_cache is not None:
911911
idx = feat_idx
@@ -939,7 +939,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
939939
feat_idx += 1
940940
else:
941941
x = self.conv_out(x)
942-
return x, feat_cache, feat_idx
942+
return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32)
943943

944944

945945
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):

0 commit comments

Comments
 (0)