Skip to content

Commit a325a13

Browse files
committed
checking device for input and outputs of vae decode
1 parent e08d5af commit a325a13

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,10 +1380,14 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13801380
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
13811381

13821382
latents = latents.astype(self.vae.dtype)
1383+
max_logging.log(f"[Profiling] Latents device before VAE Decode: {latents.devices()}")
13831384
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
1385+
max_logging.log(f"[Profiling] Video device after VAE Decode: {getattr(video, 'devices', lambda: 'Unknown')()}")
13841386
else:
13851387
latents = latents.astype(self.vae.dtype)
1388+
max_logging.log(f"[Profiling] Latents device before VAE Decode (else): {latents.devices()}")
13861389
video = self.vae.decode(latents, return_dict=False)[0]
1390+
max_logging.log(f"[Profiling] Video device after VAE Decode (else): {getattr(video, 'devices', lambda: 'Unknown')()}")
13871391
t_vae = time.perf_counter() - s_vae
13881392
max_logging.log(f"[Tuning] VAE decoding took: {t_vae:.4f} seconds")
13891393
# Post-process video (converts to numpy/PIL)

0 commit comments

Comments
 (0)