Skip to content

Commit fc67734

Browse files
committed
prompt embeds converted to bf16
1 parent 938eb82 commit fc67734

3 files changed

Lines changed: 4 additions & 2 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
2222
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
2323
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
2424
frame_rate: 30
25-
max_sequence_length: 512
25+
max_sequence_length: 1024
2626
sampler: "from_checkpoint"
2727

2828
# Generation parameters

src/maxdiffusion/generate_ltx2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9797
frame_rate=getattr(config, "fps", 24.0),
9898
decode_timestep=getattr(config, "decode_timestep", 0.0),
9999
decode_noise_scale=getattr(config, "decode_noise_scale", None),
100+
max_sequence_length=getattr(config, "max_sequence_length", 1024),
101+
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
100102
)
101103
return out
102104

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def __call__(
10991099
decode_timestep: Union[float, List[float]] = 0.0,
11001100
decode_noise_scale: Optional[Union[float, List[float]]] = None,
11011101
max_sequence_length: int = 1024,
1102-
dtype: Optional[jnp.dtype] = jnp.float32,
1102+
dtype: Optional[jnp.dtype] = None,
11031103
output_type: str = "pil",
11041104
return_dict: bool = True,
11051105
):

0 commit comments

Comments
 (0)