Skip to content

Commit d243b48

Browse files
Merge pull request #341 from AI-Hypercomputer:prisha/ltx2_transformer
PiperOrigin-RevId: 881474617
2 parents 4afed9f + 672a817 commit d243b48

6 files changed

Lines changed: 1486 additions & 2 deletions

File tree

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
# Canonical model-family identifier strings used throughout configs and
# checkpoint loading.
WAN2_1 = "wan2.1"
WAN2_2 = "wan2.2"
LTX2_VIDEO = "ltx2_video"

# Default WAN model version.
WAN_MODEL = WAN2_1
5657

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# hardware
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True
precision: 'bf16'
data_sharding: ['data', 'fsdp', 'context', 'tensor']
remat_policy: "NONE"
names_which_can_be_saved: []
names_which_can_be_offloaded: []

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'

run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

frame_rate: 30
max_sequence_length: 1024
sampler: "from_checkpoint"

# Generation parameters
dataset_name: ''
dataset_save_location: ''
global_batch_size_to_train_on: 1
num_inference_steps: 40
guidance_scale: 3.0
fps: 24
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 768
num_frames: 121
decode_timestep: 0.05
decode_noise_scale: 0.025
quantization: "int8"
seed: 10

# parallelism
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_heads', 'fsdp'],
  ['activation_batch', 'data'],
  ['activation_kv', 'tensor'],
  ['mlp', 'tensor'],
  ['embed', 'fsdp'],
  ['heads', 'tensor'],
  ['norm', 'fsdp'],
  ['conv_batch', ['data', 'fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
  ['conv_in', 'fsdp']
]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
enable_profiler: False

replicate_vae: False

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: 'Lightricks/LTX-2'
model_name: "ltx2_video"
model_type: "T2V"
unet_checkpoint: ''
checkpoint_dir: ""
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"
jit_initializers: True
enable_single_replica_ckpt_restoring: False

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,78 @@ def __call__(self, timestep, guidance, pooled_projection):
501501
conditioning = time_guidance_emb + pooled_projections
502502

503503
return conditioning
504+
505+
506+
class NNXTimesteps(nnx.Module):
  """Parameter-free sinusoidal timestep projection (NNX port).

  Wraps `get_sinusoidal_embeddings` so the projection hyper-parameters are
  configured once at construction time and applied on every call.
  """

  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
    # Plain attributes only; this module holds no learnable state.
    self.num_channels = num_channels
    self.flip_sin_to_cos = flip_sin_to_cos
    self.downscale_freq_shift = downscale_freq_shift
    self.scale = scale

  def __call__(self, timesteps: jax.Array) -> jax.Array:
    """Project a 1-D array of timesteps to `num_channels` sinusoidal features."""
    projection_kwargs = dict(
        timesteps=timesteps,
        embedding_dim=self.num_channels,
        freq_shift=self.downscale_freq_shift,
        flip_sin_to_cos=self.flip_sin_to_cos,
        scale=self.scale,
    )
    return get_sinusoidal_embeddings(**projection_kwargs)
522+
523+
524+
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """PixArt-Alpha style combined timestep + size conditioning (NNX port).

  Embeds the diffusion timestep and, when `use_additional_conditions` is set,
  also embeds resolution and aspect-ratio conditions, producing one summed
  conditioning vector.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # 256-channel sinusoidal projection followed by an MLP up to embedding_dim.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # Resolution and aspect ratio share one sinusoidal projection but each
      # gets its own MLP embedder producing size_emb_dim features.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def _embed_size_condition(self, embedder, value: jax.Array, batch: int, hidden_dtype: jnp.dtype) -> jax.Array:
    # Flatten -> shared sinusoidal projection -> per-condition MLP, then
    # restore the batch dimension so embeddings concatenate per sample.
    projected = self.additional_condition_proj(value.flatten()).astype(hidden_dtype)
    return embedder(projected).reshape(batch, -1)

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Return the conditioning vector for `timestep` (plus optional size conditions).

    Raises:
      ValueError: if additional conditions are enabled but `resolution` or
        `aspect_ratio` is missing.
    """
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))

    if not self.use_additional_conditions:
      return timesteps_emb

    if resolution is None or aspect_ratio is None:
      raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

    batch = timestep.shape[0]
    resolution_emb = self._embed_size_condition(self.resolution_embedder, resolution, batch, hidden_dtype)
    aspect_ratio_emb = self._embed_size_condition(self.aspect_ratio_embedder, aspect_ratio, batch, hidden_dtype)
    return timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)

0 commit comments

Comments
 (0)