@@ -48,6 +48,19 @@ class LTX2PipelineOutput:
4848 frames : jax .Array
4949 audio : Optional [jax .Array ] = None
5050
def rescale_noise_cfg(noise_cfg: jax.Array, noise_pred_text: jax.Array, guidance_rescale: float = 0.0) -> jax.Array:
  """
  Rescale the guided noise prediction so its per-sample std matches the text-conditioned one.

  Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).

  Args:
    noise_cfg: Noise prediction after classifier-free guidance has been applied.
    noise_pred_text: Text-conditioned noise prediction; its standard deviation is the target.
    guidance_rescale: Blend factor. `0.0` returns `noise_cfg` unchanged, `1.0` returns the
      fully rescaled prediction, values in between interpolate linearly.

  Returns:
    The blended noise prediction, same shape as `noise_cfg`.

  Note:
    The standard deviation is taken over every axis except the first, so each
    entry along axis 0 is rescaled independently. If `noise_cfg` has zero
    standard deviation for an entry, the ratio below is not finite.
  """
  # Use tuples rather than lists for `axis`: tuples are hashable, which JAX's
  # jitted reductions require of static arguments such as `axis`.
  text_axes = tuple(range(1, noise_pred_text.ndim))
  cfg_axes = tuple(range(1, noise_cfg.ndim))

  std_text = jnp.std(noise_pred_text, axis=text_axes, keepdims=True)
  std_cfg = jnp.std(noise_cfg, axis=cfg_axes, keepdims=True)

  # Rescale the guided result to the text prediction's std (fixes overexposure).
  noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

  # Mix with the un-rescaled guided result to avoid "plain looking" outputs.
  return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
63+
5164logger = logging .get_logger (__name__ )
5265
5366
@@ -741,18 +754,25 @@ def __call__(
741754 num_frames : int = 121 ,
742755 frame_rate : float = 24.0 ,
743756 num_inference_steps : int = 40 ,
757+ sigmas : Optional [List [float ]] = None ,
758+ timesteps : List [int ] = None ,
744759 guidance_scale : float = 3.0 ,
760+ guidance_rescale : float = 0.0 ,
745761 noise_scale : float = 1.0 ,
746762 num_videos_per_prompt : Optional [int ] = 1 ,
747763 generator : Optional [jax .Array ] = None ,
748764 latents : Optional [jax .Array ] = None ,
765+ audio_latents : Optional [jax .Array ] = None ,
749766 prompt_embeds : Optional [jax .Array ] = None ,
750767 negative_prompt_embeds : Optional [jax .Array ] = None ,
751768 prompt_attention_mask : Optional [jax .Array ] = None ,
752769 negative_prompt_attention_mask : Optional [jax .Array ] = None ,
770+ decode_timestep : Union [float , List [float ]] = 0.0 ,
771+ decode_noise_scale : Optional [Union [float , List [float ]]] = None ,
753772 max_sequence_length : int = 1024 ,
754773 dtype : Optional [jnp .dtype ] = jnp .float32 ,
755774 output_type : str = "pil" ,
775+ return_dict : bool = True ,
756776 ):
757777 # 1. Check inputs
758778 self .check_inputs (prompt , height , width , prompt_embeds , negative_prompt_embeds , prompt_attention_mask , negative_prompt_attention_mask )
@@ -842,7 +862,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
842862 connectors_graphdef , connectors_state , prompt_embeds_jax , prompt_attention_mask_jax .astype (jnp .bool_ )
843863 )
844864
845- for i , t in enumerate (tqdm ( timesteps ) ):
865+ for i , t in enumerate (timesteps ):
846866 noise_pred , noise_pred_audio = transformer_forward_pass (
847867 graphdef , state ,
848868 latents_jax ,
0 commit comments