Skip to content

Commit 6d38ae2

Browse files
committed
CFG gating
1 parent e79b523 commit 6d38ae2

4 files changed

Lines changed: 322 additions & 85 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ max_sequence_length: 1024
2525
sampler: "from_checkpoint"
2626

2727
# Generation parameters (aligned with Diffusers LTX-2.3 docs: use_cross_timestep, modality + audio CFG)
28+
# CFG negative-prompt encoding runs when guidance_scale>1 OR audio_guidance_scale>1 (Diffusers parity).
29+
# Modality isolation stacks with CFG even when stg_scale is 0 (pipeline stack_kind cfg_mod).
2830
global_batch_size_to_train_on: 1
2931
num_inference_steps: 30
3032
guidance_scale: 3.0

src/maxdiffusion/generate_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
106106
modality_scale=getattr(config, "modality_scale", 1.0),
107107
audio_modality_scale=getattr(config, "audio_modality_scale", None),
108108
use_cross_timestep=getattr(config, "use_cross_timestep", None),
109-
noise_scale=getattr(config, "noise_scale", 1.0),
109+
noise_scale=getattr(config, "noise_scale", 0.0),
110110
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
111111
)
112112
return out

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,7 @@ def scan_fn(carry, block_and_mask):
12361236
a2v_cross_attention_mask=encoder_attention_mask,
12371237
v2a_cross_attention_mask=audio_encoder_attention_mask,
12381238
perturbation_mask=mask,
1239+
modality_mask=modality_mask,
12391240
)
12401241

12411242
# 6. Output layers

0 commit comments

Comments
 (0)