Skip to content

Commit 6650242

Browse files
committed
spatio temporal guidance
1 parent c7b5ac4 commit 6650242

3 files changed

Lines changed: 16 additions & 3 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ sampler: "from_checkpoint"
2828
global_batch_size_to_train_on: 1
2929
num_inference_steps: 40
3030
guidance_scale: 3.0
31-
stg_scale: 0.0
32-
spatio_temporal_guidance_blocks: []
31+
audio_guidance_scale: 7.0
32+
stg_scale: 1.0
33+
audio_stg_scale: 1.0
34+
modality_scale: 3.0
35+
audio_modality_scale: 3.0
36+
spatio_temporal_guidance_blocks: [28]
3337
fps: 24
3438
pipeline_type: multi-scale
3539
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."

src/maxdiffusion/generate_ltx2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9898
decode_timestep=getattr(config, "decode_timestep", 0.0),
9999
decode_noise_scale=getattr(config, "decode_noise_scale", None),
100100
max_sequence_length=getattr(config, "max_sequence_length", 1024),
101+
audio_guidance_scale=getattr(config, "audio_guidance_scale", None),
102+
stg_scale=getattr(config, "stg_scale", 0.0),
103+
audio_stg_scale=getattr(config, "audio_stg_scale", None),
104+
modality_scale=getattr(config, "modality_scale", 1.0),
105+
audio_modality_scale=getattr(config, "audio_modality_scale", None),
101106
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
102107
)
103108
return out

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@ def create_sharded_logical_transformer(
108108
tensors: dict = None,
109109
):
110110
def create_model(rngs: nnx.Rngs, ltx2_config: dict):
111-
transformer = LTX2VideoTransformer3DModel(**ltx2_config, rngs=rngs)
111+
transformer = LTX2VideoTransformer3DModel(
112+
**ltx2_config,
113+
spatio_temporal_guidance_blocks=tuple(getattr(config, "spatio_temporal_guidance_blocks", ())),
114+
rngs=rngs
115+
)
112116
return transformer
113117

114118
# 1. Load config.

0 commit comments

Comments
 (0)