@@ -225,7 +225,7 @@ def __init__(
225225 ):
226226 self .dim = dim
227227 self .mode = mode
228- self .time_conv = None
228+ self .time_conv = nnx . data ( None )
229229
230230 if mode == "upsample2d" :
231231 self .resample = nnx .Sequential (
@@ -554,8 +554,8 @@ def __init__(
554554 precision = precision ,
555555 )
556556 )
557- self .attentions = attentions
558- self .resnets = resnets
557+ self .attentions = nnx . data ( attentions )
558+ self .resnets = nnx . data ( resnets )
559559
560560 def __call__ (self , x : jax .Array , feat_cache = None , feat_idx = [0 ]):
561561 x = self .resnets [0 ](x , feat_cache , feat_idx )
@@ -601,10 +601,10 @@ def __init__(
601601 )
602602 )
603603 current_dim = out_dim
604- self .resnets = resnets
604+ self .resnets = nnx . data ( resnets )
605605
606606 # Add upsampling layer if needed.
607- self .upsamplers = None
607+ self .upsamplers = nnx . data ( None )
608608 if upsample_mode is not None :
609609 self .upsamplers = [
610610 WanResample (
@@ -710,6 +710,7 @@ def __init__(
710710 )
711711 )
712712 scale /= 2.0
713+ self .down_blocks = nnx .data (self .down_blocks )
713714
714715 # middle_blocks
715716 self .mid_block = WanMidBlock (
@@ -873,6 +874,7 @@ def __init__(
873874 # Update scale for next iteration
874875 if upsample_mode is not None :
875876 scale *= 2.0
877+ self .up_blocks = nnx .data (self .up_blocks )
876878
877879 # output blocks
878880 self .norm_out = WanRMS_norm (dim = out_dim , images = False , rngs = rngs , channel_first = False )
0 commit comments