Skip to content

Commit aedbf16

Browse files
committed
rescale noise cfg added
1 parent 4253c19 commit aedbf16

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ class LTX2PipelineOutput:
4848
frames: jax.Array
4949
audio: Optional[jax.Array] = None
5050

51+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` so its per-sample standard deviation matches that of `noise_pred_text`.

    Intended to improve image quality and fix overexposure caused by guidance.
    Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg: Guided prediction to rescale. The standard deviation is reduced over
            every axis except axis 0 (presumably the batch axis -- confirm against caller).
        noise_pred_text: Prediction whose per-sample standard deviation is the rescaling target.
        guidance_rescale: Blend factor between the original and the rescaled prediction.
            `0.0` returns `noise_cfg` unchanged, `1.0` returns the fully rescaled prediction.

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """
    # Tuples (not lists) are hashable, so they are safe wherever `axis` is treated as static.
    std_text = jnp.std(noise_pred_text, axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True)
    std_cfg = jnp.std(noise_cfg, axis=tuple(range(1, noise_cfg.ndim)), keepdims=True)

    # A sample with zero std cannot be rescaled: dividing by it yields inf/NaN, and that NaN
    # would survive the blend below even when guidance_rescale == 0.0 (0 * nan == nan).
    # Use a ratio of 1 for such samples so they pass through untouched.
    is_degenerate = std_cfg == 0
    safe_std_cfg = jnp.where(is_degenerate, 1.0, std_cfg)
    ratio = jnp.where(is_degenerate, 1.0, std_text / safe_std_cfg)

    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * ratio
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
63+
5164
logger = logging.get_logger(__name__)
5265

5366

@@ -741,18 +754,25 @@ def __call__(
741754
num_frames: int = 121,
742755
frame_rate: float = 24.0,
743756
num_inference_steps: int = 40,
757+
sigmas: Optional[List[float]] = None,
758+
timesteps: List[int] = None,
744759
guidance_scale: float = 3.0,
760+
guidance_rescale: float = 0.0,
745761
noise_scale: float = 1.0,
746762
num_videos_per_prompt: Optional[int] = 1,
747763
generator: Optional[jax.Array] = None,
748764
latents: Optional[jax.Array] = None,
765+
audio_latents: Optional[jax.Array] = None,
749766
prompt_embeds: Optional[jax.Array] = None,
750767
negative_prompt_embeds: Optional[jax.Array] = None,
751768
prompt_attention_mask: Optional[jax.Array] = None,
752769
negative_prompt_attention_mask: Optional[jax.Array] = None,
770+
decode_timestep: Union[float, List[float]] = 0.0,
771+
decode_noise_scale: Optional[Union[float, List[float]]] = None,
753772
max_sequence_length: int = 1024,
754773
dtype: Optional[jnp.dtype] = jnp.float32,
755774
output_type: str = "pil",
775+
return_dict: bool = True,
756776
):
757777
# 1. Check inputs
758778
self.check_inputs(prompt, height, width, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask)
@@ -842,7 +862,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
842862
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
843863
)
844864

845-
for i, t in enumerate(tqdm(timesteps)):
865+
for i, t in enumerate(timesteps):
846866
noise_pred, noise_pred_audio = transformer_forward_pass(
847867
graphdef, state,
848868
latents_jax,

0 commit comments

Comments
 (0)