Skip to content

Commit 4bad196

Browse files
committed
save
1 parent fefe18e commit 4bad196

1 file changed

Lines changed: 37 additions & 23 deletions

File tree

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 37 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -740,31 +740,45 @@ def __call__(
740740
out_channels=model_config["in_channels"] // math.prod(self.patchifier.patch_size),
741741
)
742742
if output_type != "latent":
743-
if self.vae.decoder.timestep_conditioning:
744-
noise = torch.randn_like(latents)
745-
if not isinstance(decode_timestep, list):
746-
decode_timestep = [decode_timestep] * latents.shape[0]
747-
if decode_noise_scale is None:
748-
decode_noise_scale = decode_timestep
749-
elif not isinstance(decode_noise_scale, list):
750-
decode_noise_scale = [decode_noise_scale] * latents.shape[0]
751-
752-
decode_timestep = torch.tensor(decode_timestep).to(latents.device)
753-
decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[:, None, None, None, None]
754-
latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale
755-
else:
756-
decode_timestep = None
757-
image = vae_decode(
758-
latents,
759-
self.vae,
760-
is_video,
761-
vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True),
762-
timestep=decode_timestep,
763-
)
764-
image = self.image_processor.postprocess(torch.from_numpy(np.array(image.astype(jnp.float16))), output_type=output_type)
743+
if self.vae.decoder.timestep_conditioning:
744+
noise = jax.random.normal(jax.random.PRNGKey(5), latents.shape, dtype=latents.dtype) #move the key to outer layer
745+
746+
# Convert decode_timestep to a list if it's not already one
747+
if not isinstance(decode_timestep, (list, jnp.ndarray)):
748+
decode_timestep = [decode_timestep] * latents.shape[0]
749+
750+
# Handle decode_noise_scale
751+
if decode_noise_scale is None:
752+
decode_noise_scale = decode_timestep
753+
elif not isinstance(decode_noise_scale, (list, jnp.ndarray)):
754+
decode_noise_scale = [decode_noise_scale] * latents.shape[0]
755+
756+
# Convert lists to JAX arrays
757+
decode_timestep = jnp.array(decode_timestep, dtype=jnp.float32)
758+
759+
# Reshape decode_noise_scale for broadcasting
760+
decode_noise_scale = jnp.array(decode_noise_scale, dtype=jnp.float32)
761+
decode_noise_scale = jnp.reshape(decode_noise_scale, (latents.shape[0],) + (1,) * (latents.ndim - 1))
762+
763+
# Apply the noise and scale
764+
latents = (
765+
latents * (1 - decode_noise_scale) +
766+
noise * decode_noise_scale
767+
)
768+
else:
769+
decode_timestep = None
770+
image = self.vae.decode(
771+
latents = jax.device_put(latents, jax.devices('tpu')[0]), #.astype(jnp.bfloat16), #jax.device_put(latents, jax.devices('cpu')[0]),
772+
is_video = is_video,
773+
vae_per_channel_normalize=kwargs.get(
774+
"vae_per_channel_normalize", True),
775+
timestep=decode_timestep #.astype(jnp.bfloat16),
776+
)
777+
image = self.postprocess_to_output_type( #swap this out!
778+
torch.from_numpy(np.asarray(image.astype(jnp.float16))), output_type=output_type)
765779

766780
else:
767-
image = latents
781+
image = latents
768782

769783
# Offload all models
770784

0 commit comments

Comments (0)