@@ -740,31 +740,45 @@ def __call__(
740740 out_channels = model_config ["in_channels" ] // math .prod (self .patchifier .patch_size ),
741741 )
742742 if output_type != "latent" :
743- if self .vae .decoder .timestep_conditioning :
744- noise = torch .randn_like (latents )
745- if not isinstance (decode_timestep , list ):
746- decode_timestep = [decode_timestep ] * latents .shape [0 ]
747- if decode_noise_scale is None :
748- decode_noise_scale = decode_timestep
749- elif not isinstance (decode_noise_scale , list ):
750- decode_noise_scale = [decode_noise_scale ] * latents .shape [0 ]
751-
752- decode_timestep = torch .tensor (decode_timestep ).to (latents .device )
753- decode_noise_scale = torch .tensor (decode_noise_scale ).to (latents .device )[:, None , None , None , None ]
754- latents = latents * (1 - decode_noise_scale ) + noise * decode_noise_scale
755- else :
756- decode_timestep = None
757- image = vae_decode (
758- latents ,
759- self .vae ,
760- is_video ,
761- vae_per_channel_normalize = kwargs .get ("vae_per_channel_normalize" , True ),
762- timestep = decode_timestep ,
763- )
764- image = self .image_processor .postprocess (torch .from_numpy (np .array (image .astype (jnp .float16 ))), output_type = output_type )
743+ if self .vae .decoder .timestep_conditioning :
744+ noise = jax .random .normal (jax .random .PRNGKey (5 ), latents .shape , dtype = latents .dtype ) #move the key to outer layer
745+
746+ # Convert decode_timestep to a list if it's not already one
747+ if not isinstance (decode_timestep , (list , jnp .ndarray )):
748+ decode_timestep = [decode_timestep ] * latents .shape [0 ]
749+
750+ # Handle decode_noise_scale
751+ if decode_noise_scale is None :
752+ decode_noise_scale = decode_timestep
753+ elif not isinstance (decode_noise_scale , (list , jnp .ndarray )):
754+ decode_noise_scale = [decode_noise_scale ] * latents .shape [0 ]
755+
756+ # Convert lists to JAX arrays
757+ decode_timestep = jnp .array (decode_timestep , dtype = jnp .float32 )
758+
759+ # Reshape decode_noise_scale for broadcasting
760+ decode_noise_scale = jnp .array (decode_noise_scale , dtype = jnp .float32 )
761+ decode_noise_scale = jnp .reshape (decode_noise_scale , (latents .shape [0 ],) + (1 ,) * (latents .ndim - 1 ))
762+
763+ # Apply the noise and scale
764+ latents = (
765+ latents * (1 - decode_noise_scale ) +
766+ noise * decode_noise_scale
767+ )
768+ else :
769+ decode_timestep = None
770+ image = self .vae .decode (
771+ latents = jax .device_put (latents , jax .devices ('tpu' )[0 ]), #.astype(jnp.bfloat16), #jax.device_put(latents, jax.devices('cpu')[0]),
772+ is_video = is_video ,
773+ vae_per_channel_normalize = kwargs .get (
774+ "vae_per_channel_normalize" , True ),
775+ timestep = decode_timestep #.astype(jnp.bfloat16),
776+ )
777+ image = self .postprocess_to_output_type ( #swap this out!
778+ torch .from_numpy (np .asarray (image .astype (jnp .float16 ))), output_type = output_type )
765779
766780 else :
767- image = latents
781+ image = latents
768782
769783 # Offload all models
770784
0 commit comments