@@ -199,18 +199,11 @@ def __call__(
199199 )
200200
201201 if self .scheduler_state .last_sample is None or self .scheduler_state .step_index is None :
202- max_logging .log ("[DEBUG] Priming scheduler state..." )
203202 t0 = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[0 ]
204203 dummy_noise = jnp .zeros_like (latents )
205204 # This call initializes the internal state arrays
206205 step_output = self .scheduler .step (scheduler_state , dummy_noise , t0 , latents )
207- max_logging .log (f"[DEBUG] scheduler.step output type: { type (step_output )} " )
208206 scheduler_state = step_output .state
209- max_logging .log (f"[DEBUG] After prime step: scheduler_state type: { type (scheduler_state )} " )
210- if hasattr (scheduler_state , 'step_index' ):
211- max_logging .log (f"[DEBUG] Scheduler state primed: step_index={ scheduler_state .step_index is not None } , last_sample={ scheduler_state .last_sample is not None } " )
212- else :
213- max_logging .log ("[DEBUG] ERROR: scheduler_state object does not have expected attributes after priming." )
214207 graphdef , state , rest_of_state = nnx .split (self .transformer , nnx .Param , ...)
215208 data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
216209 latents = jax .device_put (latents , data_sharding )
@@ -278,6 +271,7 @@ def run_inference_2_1_i2v(
278271
279272 def loop_body (step , vals ):
280273 latents , scheduler_state , rng = vals
274+ original_dtype = latents .dtype
281275 rng , timestep_rng = jax .random .split (rng )
282276 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
283277
@@ -302,6 +296,7 @@ def loop_body(step, vals):
302296 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
303297
304298 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
299+ latents = latents .astype (original_dtype )
305300 return latents , scheduler_state , rng
306301
307302 latents , _ , _ = jax .lax .fori_loop (0 , num_inference_steps , loop_body , (latents , scheduler_state , rng ))
0 commit comments