@@ -178,11 +178,6 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
         self.dim = dim
         self.mode = mode
         self.time_conv = None
-
-        # We unpack nnx.Sequential to handle cache logic explicitly
-        self.upsample = None
-        self.conv = None
-        self.downsample_conv = None
 
         if mode == "upsample2d":
             self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
@@ -199,7 +194,7 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
 
     def initialize_cache(self, batch_size, height, width, dtype):
         cache = {}
-        if self.time_conv is not None:
+        if hasattr(self, "time_conv"):
            h_curr, w_curr = height, width
            if self.mode == "downsample3d":
                # Resample (stride 2) happens before time conv
@@ -251,10 +246,12 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
             x, _ = self.downsample_conv(x, None)
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)
-
+
             x, tc_cache = self.time_conv(x, cache.get("time_conv"))
             new_cache["time_conv"] = tc_cache
-
+        else:
+            if hasattr(self, "resample"):
+                x, _ = self.resample(x, None)
         return x, new_cache
 
 
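For context, a minimal, self-contained sketch of the pattern these hunks move toward: mode-specific submodules are attached only when the mode needs them, so `hasattr` both guards cache creation in `initialize_cache` and selects the branch in `__call__`. The class name, lambda stand-ins, and simplified cache are illustrative only, not the actual WanUpsample/time-conv modules, and the sketch assumes attributes stay undefined for inactive modes, which is what makes the `hasattr` test meaningful.

# Hypothetical sketch of hasattr-based mode dispatch; not the real module.
class ModeDispatchSketch:
    def __init__(self, mode: str):
        self.mode = mode
        # Attach submodules only for the modes that need them; leaving the
        # attribute undefined (rather than None) keeps hasattr() informative.
        if mode in ("upsample3d", "downsample3d"):
            self.time_conv = lambda x, c: (x, c)  # stand-in for a causal temporal conv
        elif mode in ("upsample2d", "downsample2d"):
            self.resample = lambda x, c: (x, c)   # stand-in for a 2d resample

    def initialize_cache(self):
        cache = {}
        # Only temporal modes carry rolling state across calls.
        if hasattr(self, "time_conv"):
            cache["time_conv"] = None
        return cache

    def __call__(self, x, cache):
        new_cache = {}
        if hasattr(self, "time_conv"):
            x, tc_cache = self.time_conv(x, cache.get("time_conv"))
            new_cache["time_conv"] = tc_cache
        elif hasattr(self, "resample"):
            x, _ = self.resample(x, None)  # purely spatial path: no cache to thread
        return x, new_cache

m = ModeDispatchSketch("upsample2d")
x, cache = m(1, m.initialize_cache())
assert cache == {}  # 2d mode produces no temporal cache

One consequence of this design is that 2d modes never allocate cache entries at all, so callers can treat an empty cache dict as "nothing temporal to carry forward" without special-casing the mode string.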