Skip to content

Commit 5a71b21

Browse files
committed
scheduler changed
1 parent 72373a0 commit 5a71b21

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,12 @@ def __call__(
230230
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
231231
)
232232

233-
if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None:
234-
t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
235-
dummy_noise = jnp.zeros_like(latents)
236-
# This call initializes the internal state arrays
237-
step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
238-
scheduler_state = step_output.state
233+
# if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None:
234+
# t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
235+
# dummy_noise = jnp.zeros_like(latents)
236+
# # This call initializes the internal state arrays
237+
# step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
238+
# scheduler_state = step_output.state
239239
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
240240
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
241241
latents = jax.device_put(latents, data_sharding)

0 commit comments

Comments
 (0)