Skip to content

Commit 0b67a19

Browse files
committed
changed upsampler
1 parent 3e6499c commit 0b67a19

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def __call__(
762762
vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True),
763763
timestep=decode_timestep,
764764
)
765-
image = self.image_processor.postprocess(image, output_type=output_type)
765+
image = self.image_processor.postprocess(torch.from_numpy(np.array(image.astype(jnp.float16))), output_type=output_type)
766766

767767
else:
768768
image = latents
@@ -983,9 +983,13 @@ def __call__(self, height, width, num_frames, output_type, generator, config) ->
983983
skip_block_list=config.first_pass["skip_block_list"],
984984
)
985985
latents = result
986-
upsampled_latents = self._upsample_latents(latent_upsampler, latents)
987-
upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents)
988-
986+
upsampled_latents = self._upsample_latents(latent_upsampler, latents) #convert back to pytorch here
987+
988+
latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu'))
989+
upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu'))
990+
upsampled_latents = adain_filter_latent(
991+
latents=upsampled_latents, reference_latents=latents
992+
)
989993
latents = upsampled_latents
990994
output_type = original_output_type
991995

0 commit comments

Comments
 (0)