Skip to content

Commit dff5c30

Browse files
committed
Merge remote-tracking branch 'origin/main' into elisatsai_ring_attention
# Conflicts: # src/maxdiffusion/pipelines/wan/wan_pipeline.py # src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
2 parents 438fefd + ddec8d9 commit dff5c30

35 files changed

Lines changed: 7699 additions & 107 deletions

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
WAN2_1 = "wan2.1"
5353
WAN2_2 = "wan2.2"
54+
LTX2_VIDEO = "ltx2_video"
5455

5556
WAN_MODEL = WAN2_1
5657

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ num_frames: 81
280280
guidance_scale: 5.0
281281
flow_shift: 3.0
282282

283+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
284+
use_cfg_cache: False
285+
283286
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
284287
guidance_rescale: 0.0
285288
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ guidance_scale_high: 4.0
303303
# timestep to switch between low noise and high noise transformer
304304
boundary_ratio: 0.875
305305

306+
# Diffusion CFG cache (FasterCache-style)
307+
use_cfg_cache: False
308+
306309
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
307310
guidance_rescale: 0.0
308311
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ num_frames: 81
286286
guidance_scale: 5.0
287287
flow_shift: 5.0
288288

289+
# Diffusion CFG cache (FasterCache-style)
290+
use_cfg_cache: False
291+
289292
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
290293
guidance_rescale: 0.0
291294
num_inference_steps: 50

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ guidance_scale_high: 4.0
298298
# timestep to switch between low noise and high noise transformer
299299
boundary_ratio: 0.875
300300

301+
# Diffusion CFG cache (FasterCache-style)
302+
use_cfg_cache: False
303+
301304
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
302305
guidance_rescale: 0.0
303306
num_inference_steps: 50
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# Base configuration for LTX-2 video generation (Lightricks/LTX-2).

# --- Hardware / runtime ---
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True
precision: 'bf16'
data_sharding: ['data', 'fsdp', 'context', 'tensor']
remat_policy: "NONE"
names_which_can_be_saved: []
names_which_can_be_offloaded: []

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'

run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

frame_rate: 30
max_sequence_length: 1024
sampler: "from_checkpoint"

# --- Generation parameters ---
dataset_name: ''
dataset_save_location: ''
global_batch_size_to_train_on: 1
num_inference_steps: 40
guidance_scale: 3.0
fps: 24
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 768
num_frames: 121
decode_timestep: 0.05
decode_noise_scale: 0.025
quantization: "int8"
seed: 10

# --- Parallelism ---
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
                      ['batch', 'data'],
                      ['activation_heads', 'fsdp'],
                      ['activation_batch', 'data'],
                      ['activation_kv', 'tensor'],
                      ['mlp','tensor'],
                      ['embed','fsdp'],
                      ['heads', 'tensor'],
                      ['norm', 'fsdp'],
                      ['conv_batch', ['data','fsdp']],
                      ['out_channels', 'tensor'],
                      ['conv_out', 'fsdp'],
                      ['conv_in', 'fsdp']
                    ]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
enable_profiler: False

replicate_vae: False

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: 'Lightricks/LTX-2'
model_name: "ltx2_video"
model_type: "T2V"
unet_checkpoint: ''
checkpoint_dir: ""
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"
jit_initializers: True
enable_single_replica_ckpt_restoring: False

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps
123123
num_inference_steps=steps,
124124
guidance_scale_low=config.guidance_scale_low,
125125
guidance_scale_high=config.guidance_scale_high,
126+
use_cfg_cache=config.use_cfg_cache,
126127
)
127128
else:
128129
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
@@ -148,6 +149,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps
148149
num_inference_steps=steps,
149150
guidance_scale_low=config.guidance_scale_low,
150151
guidance_scale_high=config.guidance_scale_high,
152+
use_cfg_cache=config.use_cfg_cache,
151153
)
152154
else:
153155
raise ValueError(f"Unsupported model_name for T2V in config: {model_key}")

src/maxdiffusion/models/attention_flax.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,18 +1033,28 @@ def __init__(
10331033
)
10341034

10351035
def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
1036-
dtype = xq.dtype
1037-
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
1038-
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
1036+
# 1. Extract cos and sin, keeping them in native bfloat16
1037+
cos = jnp.real(freqs_cis).astype(xq.dtype)
1038+
sin = jnp.imag(freqs_cis).astype(xq.dtype)
10391039

1040-
xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
1041-
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
1040+
# 2. Reshape the last dimension into pairs
1041+
xq_reshaped = xq.reshape(*xq.shape[:-1], -1, 2)
1042+
xk_reshaped = xk.reshape(*xk.shape[:-1], -1, 2)
10421043

1043-
xq_out_complex = xq_ * freqs_cis
1044-
xk_out_complex = xk_ * freqs_cis
1044+
# 3. Unbind the pairs
1045+
xq_0, xq_1 = xq_reshaped[..., 0], xq_reshaped[..., 1]
1046+
xk_0, xk_1 = xk_reshaped[..., 0], xk_reshaped[..., 1]
10451047

1046-
xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype)
1047-
xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype)
1048+
# 4. Pure real arithmetic (XLA will fuse these instantly into FMA instructions)
1049+
xq_out_0 = xq_0 * cos - xq_1 * sin
1050+
xq_out_1 = xq_0 * sin + xq_1 * cos
1051+
1052+
xk_out_0 = xk_0 * cos - xk_1 * sin
1053+
xk_out_1 = xk_0 * sin + xk_1 * cos
1054+
1055+
# 5. Stack and reshape back to original
1056+
xq_out = jnp.stack([xq_out_0, xq_out_1], axis=-1).reshape(xq.shape)
1057+
xk_out = jnp.stack([xk_out_0, xk_out_1], axis=-1).reshape(xk.shape)
10481058

10491059
return xq_out, xk_out
10501060

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,78 @@ def __call__(self, timestep, guidance, pooled_projection):
501501
conditioning = time_guidance_emb + pooled_projections
502502

503503
return conditioning
504+
505+
506+
class NNXTimesteps(nnx.Module):
  """Stateless NNX wrapper producing sinusoidal timestep embeddings.

  Thin NNX-module front end over ``get_sinusoidal_embeddings`` so the
  projection can be composed inside other NNX modules.
  """

  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
    # Width of the produced embedding vector.
    self.num_channels = num_channels
    # If True, the cos half precedes the sin half in the output.
    self.flip_sin_to_cos = flip_sin_to_cos
    # Shift applied to the frequency exponent denominator.
    self.downscale_freq_shift = downscale_freq_shift
    # Multiplier applied to the embeddings before sin/cos.
    self.scale = scale

  def __call__(self, timesteps: jax.Array) -> jax.Array:
    """Return sinusoidal embeddings of shape (len(timesteps), num_channels)."""
    return get_sinusoidal_embeddings(
        timesteps=timesteps,
        embedding_dim=self.num_channels,
        freq_shift=self.downscale_freq_shift,
        flip_sin_to_cos=self.flip_sin_to_cos,
        scale=self.scale,
    )
522+
523+
524+
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """NNX port of PixArt-Alpha's combined timestep + size conditioning embedder.

  Embeds the diffusion timestep and, optionally, image resolution and
  aspect-ratio conditioning signals, summing them into one conditioning
  vector of width ``embedding_dim``.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    # NOTE(review): stored for parity with the diffusers layer but not read
    # anywhere in this class — confirm whether callers rely on it.
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # Sinusoidal projection (256-wide) followed by an MLP to embedding_dim.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # One shared sinusoidal projection feeds both size embedders.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Return the conditioning vector for ``timestep`` (plus optional sizes).

    Args:
      timestep: per-sample diffusion timesteps, leading dim = batch.
      resolution: required when ``use_additional_conditions`` is True.
      aspect_ratio: required when ``use_additional_conditions`` is True.
      hidden_dtype: dtype the embedder MLPs are evaluated in.

    Raises:
      ValueError: if additional conditions are enabled but resolution or
        aspect_ratio is missing.
    """
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))

    if self.use_additional_conditions:
      if resolution is None or aspect_ratio is None:
        raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

      # Flatten so each size component is embedded independently, then fold
      # the components back into one row per batch element.
      resolution_emb = self.additional_condition_proj(resolution.flatten()).astype(hidden_dtype)
      resolution_emb = self.resolution_embedder(resolution_emb)
      resolution_emb = resolution_emb.reshape(timestep.shape[0], -1)

      aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).astype(hidden_dtype)
      aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb)
      aspect_ratio_emb = aspect_ratio_emb.reshape(timestep.shape[0], -1)

      conditioning = timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)
    else:
      conditioning = timesteps_emb

    return conditioning
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""

0 commit comments

Comments
 (0)