Skip to content

Commit cbb8cfd

Browse files
committed
Add LTX2 Transformer integrated with Attention.
1 parent b7d9c55 commit cbb8cfd

7 files changed

Lines changed: 2828 additions & 162 deletions

File tree

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
#hardware
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

#Checkpoints
text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
frame_rate: 30
max_sequence_length: 512
sampler: "from_checkpoint"

# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 512
num_frames: 88
flow_shift: 5.0
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
prompt_enhancement_words_threshold: 120
stg_mode: "attention_values"
decode_timestep: 0.05
decode_noise_scale: 0.025
seed: 10
conditioning_media_paths: None #["IMAGE_PATH"]
conditioning_start_frames: [0]


first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  skip_initial_inference_steps: 0
  cfg_star_rescale: True

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  skip_final_inference_steps: 0
  cfg_star_rescale: True

#parallelism
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_heads', 'fsdp'],
  ['activation_batch', 'data'],
  ['activation_kv', 'tensor'],
  ['mlp','tensor'],
  ['embed','fsdp'],
  ['heads', 'tensor'],
  ['norm', 'fsdp'],
  ['conv_batch', ['data','fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
  ['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
jit_initializers: True
enable_single_replica_ckpt_restoring: False

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,80 @@ def __call__(self, timestep, guidance, pooled_projection):
501501
conditioning = time_guidance_emb + pooled_projections
502502

503503
return conditioning
504+
505+
506+
class NNXTimesteps(nnx.Module):
  """Stateless sinusoidal timestep projection.

  Thin configuration wrapper around ``get_sinusoidal_embeddings``: the
  projection hyper-parameters are fixed at construction time and reused
  on every call. Holds no learned parameters.
  """

  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
    # Only configuration is stored; nothing here is trainable state.
    self.scale = scale
    self.downscale_freq_shift = downscale_freq_shift
    self.flip_sin_to_cos = flip_sin_to_cos
    self.num_channels = num_channels

  def __call__(self, timesteps: jax.Array) -> jax.Array:
    """Project ``timesteps`` to ``num_channels``-wide sinusoidal features."""
    embeddings = get_sinusoidal_embeddings(
        timesteps=timesteps,
        embedding_dim=self.num_channels,
        freq_shift=self.downscale_freq_shift,
        flip_sin_to_cos=self.flip_sin_to_cos,
        scale=self.scale,
    )
    return embeddings
522+
523+
524+
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """Combined timestep (and optional resolution / aspect-ratio) embedding.

  Mirrors diffusers' PixArt-Alpha conditioning: the timestep is projected
  sinusoidally and passed through an MLP embedder; when
  ``use_additional_conditions`` is set, resolution and aspect-ratio signals
  are embedded the same way and concatenated onto the timestep embedding.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # Timestep path: 256-wide sinusoidal projection -> MLP embedder.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # One shared sinusoidal projection feeds two separate embedders.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def _embed_condition(self, signal: jax.Array, embedder, batch_size: int, hidden_dtype: jnp.dtype) -> jax.Array:
    # Shared path for resolution / aspect-ratio: sinusoidal projection,
    # MLP embedding, then reshape to (batch_size, -1) as in the PyTorch
    # reference implementation.
    projected = self.additional_condition_proj(signal.flatten()).astype(hidden_dtype)
    return embedder(projected).reshape(batch_size, -1)

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Return the conditioning embedding for ``timestep`` (plus size info if enabled).

    Raises:
      ValueError: if additional conditions are enabled but ``resolution``
        or ``aspect_ratio`` is missing.
    """
    timesteps_emb = self.timestep_embedder(self.time_proj(timestep).astype(hidden_dtype))

    # Fast path: no size conditioning requested.
    if not self.use_additional_conditions:
      return timesteps_emb

    if resolution is None or aspect_ratio is None:
      raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

    batch_size = timestep.shape[0]
    resolution_emb = self._embed_condition(resolution, self.resolution_embedder, batch_size, hidden_dtype)
    aspect_ratio_emb = self._embed_condition(aspect_ratio, self.aspect_ratio_embedder, batch_size, hidden_dtype)
    return timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)

0 commit comments

Comments
 (0)