Skip to content

Commit b32cda7

Browse files
committed
dtype casting issue fixed
1 parent e7c5527 commit b32cda7

1 file changed

Lines changed: 2 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,18 +199,11 @@ def __call__(
199199
)
200200

201201
if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None:
202-
max_logging.log("[DEBUG] Priming scheduler state...")
203202
t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
204203
dummy_noise = jnp.zeros_like(latents)
205204
# This call initializes the internal state arrays
206205
step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
207-
max_logging.log(f"[DEBUG] scheduler.step output type: {type(step_output)}")
208206
scheduler_state = step_output.state
209-
max_logging.log(f"[DEBUG] After prime step: scheduler_state type: {type(scheduler_state)}")
210-
if hasattr(scheduler_state, 'step_index'):
211-
max_logging.log(f"[DEBUG] Scheduler state primed: step_index={scheduler_state.step_index is not None}, last_sample={scheduler_state.last_sample is not None}")
212-
else:
213-
max_logging.log("[DEBUG] ERROR: scheduler_state object does not have expected attributes after priming.")
214207
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
215208
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
216209
latents = jax.device_put(latents, data_sharding)
@@ -278,6 +271,7 @@ def run_inference_2_1_i2v(
278271

279272
def loop_body(step, vals):
280273
latents, scheduler_state, rng = vals
274+
original_dtype = latents.dtype
281275
rng, timestep_rng = jax.random.split(rng)
282276
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
283277

@@ -302,6 +296,7 @@ def loop_body(step, vals):
302296
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
303297

304298
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
299+
latents = latents.astype(original_dtype)
305300
return latents, scheduler_state, rng
306301

307302
latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state, rng))

0 commit comments

Comments
 (0)