Skip to content

Commit e7c5527

Browse files
committed
fix for step index error
1 parent 649e8c1 commit e7c5527

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,12 @@ def __call__(
205205
# This call initializes the internal state arrays
206206
step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
207207
max_logging.log(f"[DEBUG] scheduler.step output type: {type(step_output)}")
208-
max_logging.log(f"[DEBUG] scheduler.step output value: {step_output}")
209-
_, scheduler_state = step_output
208+
scheduler_state = step_output.state
210209
max_logging.log(f"[DEBUG] After prime step: scheduler_state type: {type(scheduler_state)}")
211-
max_logging.log(f"[DEBUG] After prime step: scheduler_state value: {scheduler_state}")
212-
max_logging.log(f"[DEBUG] Scheduler state primed: step_index={scheduler_state.step_index is not None}, last_sample={scheduler_state.last_sample is not None}")
213-
210+
if hasattr(scheduler_state, 'step_index'):
211+
max_logging.log(f"[DEBUG] Scheduler state primed: step_index={scheduler_state.step_index is not None}, last_sample={scheduler_state.last_sample is not None}")
212+
else:
213+
max_logging.log("[DEBUG] ERROR: scheduler_state object does not have expected attributes after priming.")
214214
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
215215
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
216216
latents = jax.device_put(latents, data_sharding)

0 commit comments

Comments
 (0)