Commit c4d072b

feat(ltx2): implement centralized and configuration-driven sharding strategy
1 parent 3ef0fdd commit c4d072b

7 files changed

Lines changed: 307 additions & 84 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 12 additions & 7 deletions
@@ -15,7 +15,7 @@
 import contextlib
 import functools
 import math
-from typing import Optional, Callable, Tuple
+from typing import Optional, Callable, Tuple, Any
 import flax.linen as nn
 from flax import nnx
 import jax
@@ -607,8 +607,6 @@ def wrap_ulysses_attention(query, key, value):
         orig_q_seq_len=query_seq_len,
         orig_kv_seq_len=key_seq_len,
         heads_per_tile=heads_per_tile,
-        use_base2_exp=use_base2_exp,
-        use_experimental_scheduler=use_experimental_scheduler,
     )

     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
@@ -1106,9 +1104,16 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: Optional[jax.lax.Precision] = None,
+      sharding_specs: Optional[Any] = None,
   ):
     inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
+
+    net_0_kernel = getattr(sharding_specs, "net_0_kernel", (None, "mlp"))
+    net_0_bias = getattr(sharding_specs, "net_0_bias", ("mlp",))
+    net_2_kernel = getattr(sharding_specs, "net_2_kernel", ("mlp", None))
+    net_2_bias = getattr(sharding_specs, "net_2_bias", (None,))
+
     self.net_0 = nnx.Linear(
         dim,
         inner_dim,
@@ -1117,8 +1122,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_0_kernel),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_0_bias),
     )
     self.act = get_activation(activation_fn)
     self.net_2 = nnx.Linear(
@@ -1129,8 +1134,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_2_kernel),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_2_bias),
     )

   def __call__(self, hidden_states: Array) -> Array:

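Note on the feed-forward change above: the new sharding_specs argument is duck-typed (Optional[Any]), so any object exposing the net_* attributes works, and None silently keeps the previous hard-coded axes. Below is a minimal sketch of that getattr fallback pattern; the _FFNSpecs dataclass is an illustrative stand-in, not the real LTX2DiTShardingSpecs.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _FFNSpecs:  # illustrative stand-in for LTX2DiTShardingSpecs
  net_0_kernel: tuple = (None, "mlp")
  net_2_kernel: tuple = ("mlp", None)


def resolve_ffn_specs(sharding_specs: Optional[Any]):
  # getattr falls back to the original hard-coded defaults when sharding_specs
  # is None or lacks the attribute, so call sites that never pass specs keep
  # their previous partitioning.
  net_0_kernel = getattr(sharding_specs, "net_0_kernel", (None, "mlp"))
  net_2_kernel = getattr(sharding_specs, "net_2_kernel", ("mlp", None))
  return net_0_kernel, net_2_kernel


print(resolve_ffn_specs(None))                                      # ((None, 'mlp'), ('mlp', None))
print(resolve_ffn_specs(_FFNSpecs(net_0_kernel=("fsdp", "mlp"))))   # (('fsdp', 'mlp'), ('mlp', None))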
src/maxdiffusion/models/embeddings_flax.py

Lines changed: 41 additions & 24 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Optional, Any
 import flax.linen as nn
 from flax import nnx
 import jax.numpy as jnp
@@ -60,6 +60,9 @@ def get_sinusoidal_embeddings(
   return signal


+
+
+
 class NNXTimestepEmbedding(nnx.Module):
   r"""
   Time step Embedding Module. Learns embeddings for input time steps.
@@ -84,7 +87,12 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
+    linear_1_kernel = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
+    linear_1_bias = getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
+    linear_2_kernel = getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
+    linear_2_bias = getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
     self.linear_1 = nnx.Linear(
         rngs=rngs,
         in_features=in_channels,
@@ -95,12 +103,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "embed",
-                "mlp",
-            ),
+            linear_1_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
     )

     if cond_proj_dim is not None:
@@ -127,12 +132,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "mlp",
-                "embed",
-            ),
+            linear_2_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
     )

     if post_act_fn is None:
@@ -336,7 +338,12 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
+    linear_1_kernel = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
+    linear_1_bias = getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
+    linear_2_kernel = getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
+    linear_2_bias = getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
     if out_features is None:
       out_features = hidden_size

@@ -350,12 +357,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "embed",
-                "mlp",
-            ),
+            linear_1_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
     )
     self.act_1 = get_activation(act_fn)

@@ -369,12 +373,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "mlp",
-                "embed",
-            ),
+            linear_2_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
     )

   def __call__(self, caption):
@@ -530,22 +531,38 @@ def __init__(
       use_additional_conditions: bool = False,
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
+      sharding_specs: Optional[Any] = None,
   ):
     self.outdim = size_emb_dim
     self.use_additional_conditions = use_additional_conditions

     self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
     self.timestep_embedder = NNXTimestepEmbedding(
-        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
+        rngs=rngs,
+        in_channels=256,
+        time_embed_dim=embedding_dim,
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        sharding_specs=sharding_specs,
     )

     if use_additional_conditions:
       self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
       self.resolution_embedder = NNXTimestepEmbedding(
-          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
+          rngs=rngs,
+          in_channels=256,
+          time_embed_dim=size_emb_dim,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          sharding_specs=sharding_specs,
       )
       self.aspect_ratio_embedder = NNXTimestepEmbedding(
-          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
+          rngs=rngs,
+          in_channels=256,
+          time_embed_dim=size_emb_dim,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          sharding_specs=sharding_specs,
      )

   def __call__(

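The embedding modules read their specs the same way. Because NNXTimestepEmbedding and NNXPixArtAlphaTextProjection share the emb_linear_* attribute names (see the new specs file below), a single spec object can drive both. A rough sketch, not the repository's code, using an illustrative stand-in dataclass:

from dataclasses import dataclass


@dataclass
class _EmbSpecs:  # illustrative stand-in for LTX2DiTShardingSpecs
  emb_linear_1_kernel: tuple = ("embed", "mlp")
  emb_linear_1_bias: tuple = ("mlp",)
  emb_linear_2_kernel: tuple = ("mlp", "embed")
  emb_linear_2_bias: tuple = ("embed",)


specs = _EmbSpecs(emb_linear_1_kernel=(None, "mlp"), emb_linear_2_kernel=("mlp", None))

# Both constructors look the fields up the same way, and None keeps the legacy defaults:
assert getattr(specs, "emb_linear_1_kernel", ("embed", "mlp")) == (None, "mlp")
assert getattr(None, "emb_linear_1_kernel", ("embed", "mlp")) == ("embed", "mlp")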
src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 11 additions & 21 deletions
@@ -21,6 +21,7 @@
 from ... import common_types
 from ..attention_flax import NNXAttentionOp
 from maxdiffusion.tpu_utils import get_tpu_type, TpuType
+from .logical_sharding_ltx2 import get_sharding_specs, LTX2DiTShardingSpecs

 Array = common_types.Array
 Mesh = common_types.Mesh
@@ -350,43 +351,32 @@ def __init__(
       rope_type: str = "interleaved",
       flash_block_sizes: BlockSizes = None,
       flash_min_seq_length: int = 4096,
-      qkv_sharding_spec: Optional[tuple] = None,
-      out_sharding_spec: Optional[tuple] = None,
-      out_bias_sharding_spec: Optional[tuple] = None,
+      sharding_specs: Optional[LTX2DiTShardingSpecs] = None,
   ):
     self.heads = heads
     self.rope_type = rope_type
     self.dim_head = dim_head
     self.inner_dim = dim_head * heads
     self.dropout_rate = dropout

-    # Auto-detect hardware for sharding specs if not overridden
-    tpu_type = get_tpu_type()
-    is_ironwood = tpu_type == TpuType.TPU_7X
-
-    # Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
-    # to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
-    # This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
-    if qkv_sharding_spec is None:
-      qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
-    if out_sharding_spec is None:
-      out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
-    if out_bias_sharding_spec is None:
-      out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
+    if sharding_specs is None:
+      specs = get_sharding_specs("default", "ltx2_dit")
+    else:
+      specs = sharding_specs

     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
-    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
+    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
     # Q, K, V biases: [out_features (heads)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)

     # Out kernel: [in_features (heads), out_features (embed)]
-    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
+    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
     # Out bias: [out_features (embed)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)

     # Norm scales
-    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
+    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)

     # 2. Projections
     self.to_q = nnx.Linear(
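With this change the attention constructor no longer inspects the TPU type itself; it defers to the "default" strategy of the new registry. A hedged sketch of the resolution path follows (the import path assumes the new module sits next to attention_ltx2.py, as the relative import above implies; the surrounding module construction is omitted):

from typing import Optional

from flax import nnx

from maxdiffusion.models.ltx2.logical_sharding_ltx2 import LTX2DiTShardingSpecs, get_sharding_specs


def resolve_attention_inits(sharding_specs: Optional[LTX2DiTShardingSpecs] = None):
  # Mirrors the constructor above: an explicit spec object wins, otherwise the
  # "default" strategy auto-detects the TPU generation inside get_sharding_specs.
  specs = get_sharding_specs("default", "ltx2_dit") if sharding_specs is None else sharding_specs
  qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
  out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
  return qkv_kernel_init, out_kernel_init


# Forcing the Ironwood layout regardless of the hardware the process runs on:
ironwood_inits = resolve_attention_inits(get_sharding_specs("ironwood", "ltx2_dit"))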
src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py

Lines changed: 116 additions & 0 deletions (new file)
"""
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dataclasses import dataclass
from typing import Any, Optional
from maxdiffusion.tpu_utils import get_tpu_type, TpuType


# --- Discrete Specs ---
@dataclass
class LTX2DiTShardingSpecs:
  """Sharding specs for the LTX2 Diffusion Transformer."""

  # --- Attention Layers (LTX2Attention) ---
  qkv_kernel: tuple
  out_kernel: tuple
  out_bias: tuple
  qkv_bias: tuple = ("heads",)

  # --- Feed Forward Network (NNXSimpleFeedForward) ---
  net_0_kernel: tuple = (None, "mlp")
  net_0_bias: tuple = ("mlp",)
  net_2_kernel: tuple = ("mlp", None)
  net_2_bias: tuple = (None,)

  # --- Input/Output Projections and Tables ---
  embed_kernel: tuple = (None, "embed")
  embed_bias: tuple = ("embed",)
  out_embed_kernel: tuple = ("embed", None)
  out_embed_bias: tuple = (None,)

  # --- Shared Embeddings (NNXTimestepEmbedding, NNXPixArtAlphaTextProjection) ---
  emb_linear_1_kernel: tuple = ("embed", "mlp")
  emb_linear_1_bias: tuple = ("mlp",)
  emb_linear_2_kernel: tuple = ("mlp", "embed")
  emb_linear_2_bias: tuple = ("embed",)

  # --- Normalization ---
  norm_scale: tuple = ("norm",)


@dataclass
class TextEncoderShardingSpecs:
  """Specs for the Text Encoder execution."""

  use_batched_text_encoder: bool = False
  text_encoder_kernel: Optional[tuple] = None


@dataclass
class VAEShardingSpecs:
  """Sharding specs for the VAE."""

  vae_conv_kernel: Optional[tuple] = None
  force_replication: bool = True


# --- Unified Registry for LTX2 ---
STRATEGIES = {
    "ironwood": {
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=(None, "heads"),
            out_kernel=("heads", None),
            out_bias=(None,),
        ),
        "text_encoder": TextEncoderShardingSpecs(
            use_batched_text_encoder=True,
            text_encoder_kernel=(None, "embed"),
        ),
        "vae": VAEShardingSpecs(vae_conv_kernel=("batch", None, None, None)),
    },
    "trillium": {
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=("embed", "heads"),
            out_kernel=("heads", "embed"),
            out_bias=("embed",),
        ),
        "text_encoder": TextEncoderShardingSpecs(
            use_batched_text_encoder=False,
            text_encoder_kernel=(None, "embed"),
        ),
        "vae": VAEShardingSpecs(vae_conv_kernel=(None, None, None, None)),
    },
}


def get_sharding_specs(strategy_name: str, component_name: str) -> Any:
  """Unified factory to get specs for any component.

  If strategy_name is 'default', it auto-detects the hardware.
  """
  if strategy_name == "default":
    tpu_type = get_tpu_type()
    if tpu_type == TpuType.TPU_7X:
      strategy_name = "ironwood"
    else:
      strategy_name = "trillium"

  hardware_profile = STRATEGIES.get(strategy_name, STRATEGIES["trillium"])
  specs = hardware_profile.get(component_name)
  if specs is None:
    raise ValueError(f"Component {component_name} not found in strategy {strategy_name}")
  return specs
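A minimal usage sketch of the factory above; reading the strategy name from a config value is an assumption on my part, only the lookups themselves come from this file:

from maxdiffusion.models.ltx2.logical_sharding_ltx2 import get_sharding_specs

strategy = "default"  # e.g. taken from config; "default" resolves via get_tpu_type()
dit_specs = get_sharding_specs(strategy, "ltx2_dit")
text_specs = get_sharding_specs(strategy, "text_encoder")
vae_specs = get_sharding_specs(strategy, "vae")

print(dit_specs.qkv_kernel)             # (None, "heads") on v7x, ("embed", "heads") otherwise
print(text_specs.use_batched_text_encoder)
print(vae_specs.force_replication)      # True in both built-in strategies

# An unknown strategy name silently falls back to the "trillium" profile,
# while an unknown component name raises ValueError.
assert get_sharding_specs("unknown", "ltx2_dit").qkv_kernel == ("embed", "heads")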
