
Commit 66401d7: Centralized sharding for LTX2

1 parent: 3ef0fdd

11 files changed: 366 additions & 88 deletions

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,12 @@ logical_axis_rules: [
 ]
 data_sharding: ['data', 'fsdp', 'context', 'tensor']
 
+sharding:
+  transformer: 'default'
+  vae: 'default'
+  text_encoder: 'default'
+  text_connector: 'default'
+
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 
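The new `sharding` block selects a named sharding profile per model component. This excerpt does not show how the profiles are consumed; a hypothetical sketch of the wiring, assuming each YAML key is resolved through the `get_sharding_specs` helper imported in attention_ltx2.py below (only the `("default", "ltx2_dit")` call is confirmed by this diff; using it for every component is an illustrative guess):

from maxdiffusion.models.ltx2.logical_sharding_ltx2 import get_sharding_specs

def resolve_component_specs(sharding_cfg):
  # sharding_cfg is the parsed `sharding:` mapping from ltx2_video.yml,
  # e.g. {"transformer": "default", "vae": "default", ...}.
  # NOTE: the "ltx2_dit" model key for every component is a guess.
  return {name: get_sharding_specs(profile, "ltx2_dit") for name, profile in sharding_cfg.items()}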

src/maxdiffusion/models/attention_flax.py

Lines changed: 12 additions & 7 deletions
@@ -15,7 +15,7 @@
 import contextlib
 import functools
 import math
-from typing import Optional, Callable, Tuple
+from typing import Optional, Callable, Tuple, Any
 import flax.linen as nn
 from flax import nnx
 import jax
@@ -607,8 +607,6 @@ def wrap_ulysses_attention(query, key, value):
       orig_q_seq_len=query_seq_len,
       orig_kv_seq_len=key_seq_len,
       heads_per_tile=heads_per_tile,
-      use_base2_exp=use_base2_exp,
-      use_experimental_scheduler=use_experimental_scheduler,
   )
 
   vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
@@ -1106,9 +1104,16 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: Optional[jax.lax.Precision] = None,
+      sharding_specs: Optional[Any] = None,
   ):
     inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
+
+    net_0_kernel = getattr(sharding_specs, "net_0_kernel", (None, "mlp"))
+    net_0_bias = getattr(sharding_specs, "net_0_bias", ("mlp",))
+    net_2_kernel = getattr(sharding_specs, "net_2_kernel", ("mlp", None))
+    net_2_bias = getattr(sharding_specs, "net_2_bias", (None,))
+
     self.net_0 = nnx.Linear(
         dim,
         inner_dim,
@@ -1117,8 +1122,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_0_kernel),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_0_bias),
     )
     self.act = get_activation(activation_fn)
     self.net_2 = nnx.Linear(
@@ -1129,8 +1134,8 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_2_kernel),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_2_bias),
     )
 
   def __call__(self, hidden_states: Array) -> Array:
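The `getattr` lookups above are what make this change backward compatible: when `sharding_specs` is None, or an object that lacks a given field, the third argument supplies the previous hard-coded logical axes. A minimal, self-contained illustration of that fallback (the `SimpleNamespace` stand-in is illustrative, not the real specs type):

from types import SimpleNamespace

# No specs object: getattr on None returns the default, so the old
# hard-coded axes are used and behavior is unchanged.
assert getattr(None, "net_0_kernel", (None, "mlp")) == (None, "mlp")

# A specs object that defines the field overrides the default.
specs = SimpleNamespace(net_0_kernel=("embed", "mlp"))
assert getattr(specs, "net_0_kernel", (None, "mlp")) == ("embed", "mlp")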

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 38 additions & 24 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Optional, Any
 import flax.linen as nn
 from flax import nnx
 import jax.numpy as jnp
@@ -84,7 +84,12 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
+    linear_1_kernel = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
+    linear_1_bias = getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
+    linear_2_kernel = getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
+    linear_2_bias = getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
     self.linear_1 = nnx.Linear(
         rngs=rngs,
         in_features=in_channels,
@@ -95,12 +100,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "embed",
-                "mlp",
-            ),
+            linear_1_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
     )
 
     if cond_proj_dim is not None:
@@ -127,12 +129,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "mlp",
-                "embed",
-            ),
+            linear_2_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
     )
 
     if post_act_fn is None:
@@ -336,7 +335,12 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
+    linear_1_kernel = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
+    linear_1_bias = getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
+    linear_2_kernel = getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
+    linear_2_bias = getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
     if out_features is None:
       out_features = hidden_size
 
@@ -350,12 +354,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "embed",
-                "mlp",
-            ),
+            linear_1_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
     )
     self.act_1 = get_activation(act_fn)
 
@@ -369,12 +370,9 @@ def __init__(
         precision=precision,
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
-            (
-                "mlp",
-                "embed",
-            ),
+            linear_2_kernel,
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
     )
 
   def __call__(self, caption):
@@ -530,22 +528,38 @@ def __init__(
       use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
+      sharding_specs: Optional[Any] = None,
   ):
     self.outdim = size_emb_dim
     self.use_additional_conditions = use_additional_conditions
 
     self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
     self.timestep_embedder = NNXTimestepEmbedding(
-        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
+        rngs=rngs,
+        in_channels=256,
+        time_embed_dim=embedding_dim,
+        dtype=dtype,
+        weights_dtype=weights_dtype,
+        sharding_specs=sharding_specs,
     )
 
     if use_additional_conditions:
       self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
       self.resolution_embedder = NNXTimestepEmbedding(
-          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
+          rngs=rngs,
+          in_channels=256,
+          time_embed_dim=size_emb_dim,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          sharding_specs=sharding_specs,
      )
       self.aspect_ratio_embedder = NNXTimestepEmbedding(
-          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
+          rngs=rngs,
+          in_channels=256,
+          time_embed_dim=size_emb_dim,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          sharding_specs=sharding_specs,
      )
 
   def __call__(
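For context, `nnx.with_partitioning` attaches the logical axis names as metadata on the parameter it initializes; the `logical_axis_rules` list in ltx2_video.yml then maps those names onto physical mesh axes. A small sketch of this standard Flax NNX behavior (not code from this commit):

from flax import nnx

kernel_init = nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp"))
layer = nnx.Linear(256, 1024, kernel_init=kernel_init, rngs=nnx.Rngs(0))

# The logical axis names travel with the parameter as sharding metadata:
print(layer.kernel.sharding)  # ('embed', 'mlp')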

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 11 additions & 21 deletions
@@ -21,6 +21,7 @@
 from ... import common_types
 from ..attention_flax import NNXAttentionOp
 from maxdiffusion.tpu_utils import get_tpu_type, TpuType
+from .logical_sharding_ltx2 import get_sharding_specs, LTX2DiTShardingSpecs
 
 Array = common_types.Array
 Mesh = common_types.Mesh
@@ -350,43 +351,32 @@ def __init__(
       rope_type: str = "interleaved",
       flash_block_sizes: BlockSizes = None,
       flash_min_seq_length: int = 4096,
-      qkv_sharding_spec: Optional[tuple] = None,
-      out_sharding_spec: Optional[tuple] = None,
-      out_bias_sharding_spec: Optional[tuple] = None,
+      sharding_specs: Optional[LTX2DiTShardingSpecs] = None,
   ):
     self.heads = heads
     self.rope_type = rope_type
     self.dim_head = dim_head
     self.inner_dim = dim_head * heads
     self.dropout_rate = dropout
 
-    # Auto-detect hardware for sharding specs if not overridden
-    tpu_type = get_tpu_type()
-    is_ironwood = tpu_type == TpuType.TPU_7X
-
-    # Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
-    # to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
-    # This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
-    if qkv_sharding_spec is None:
-      qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
-    if out_sharding_spec is None:
-      out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
-    if out_bias_sharding_spec is None:
-      out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
+    if sharding_specs is None:
+      specs = get_sharding_specs("default", "ltx2_dit")
+    else:
+      specs = sharding_specs
 
     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
-    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
+    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
     # Q, K, V biases: [out_features (heads)]
-    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)
 
     # Out kernel: [in_features (heads), out_features (embed)]
-    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
+    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
     # Out bias: [out_features (embed)]
-    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)
 
     # Norm scales
-    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
+    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)
 
     # 2. Projections
     self.to_q = nnx.Linear(
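The new `logical_sharding_ltx2` module itself is not included in this excerpt. A minimal sketch of the interface implied by its call sites: the field names come from the `specs.*` accesses above, the `"default"` axes from the hard-coded values this hunk deletes, and the 1D variant from the removed Ironwood branch; the profile registry and the `"ironwood"` profile name are guesses.

from dataclasses import dataclass
from typing import Optional, Tuple

Axes = Tuple[Optional[str], ...]

@dataclass(frozen=True)
class LTX2DiTShardingSpecs:
  qkv_kernel: Axes = ("embed", "heads")  # old non-Ironwood default
  qkv_bias: Axes = ("heads",)
  out_kernel: Axes = ("heads", "embed")
  out_bias: Axes = ("embed",)
  norm_scale: Axes = ("norm",)

_PROFILES = {
    ("default", "ltx2_dit"): LTX2DiTShardingSpecs(),
    # 1D sharding that the deleted is_ironwood branch computed inline:
    ("ironwood", "ltx2_dit"): LTX2DiTShardingSpecs(
        qkv_kernel=(None, "heads"), out_kernel=("heads", None), out_bias=(None,)
    ),
}

def get_sharding_specs(profile: str, model: str) -> LTX2DiTShardingSpecs:
  return _PROFILES[(profile, model)]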

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 11 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple, Union, Optional, Sequence
+from typing import Tuple, Union, Optional, Sequence, Any
 
 import jax
 import jax.numpy as jnp
@@ -584,6 +584,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     if timestep_conditioning:
       self.time_embedder = nnx.data(
@@ -594,6 +595,7 @@ def __init__(
              use_additional_conditions=False,
              dtype=dtype,
              weights_dtype=weights_dtype,
+             sharding_specs=sharding_specs,
          )
       )
     else:
@@ -674,6 +676,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     out_channels = out_channels or in_channels
 
@@ -687,6 +690,7 @@ def __init__(
              use_additional_conditions=False,
              dtype=dtype,
              weights_dtype=weights_dtype,
+             sharding_specs=sharding_specs,
          )
       )
 
@@ -960,6 +964,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     self.patch_size = patch_size
     self.patch_size_t = patch_size_t
@@ -999,6 +1004,7 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
+        sharding_specs=sharding_specs,
     )
 
     # up blocks
@@ -1026,6 +1032,7 @@ def __init__(
            dtype=dtype,
            weights_dtype=weights_dtype,
            precision=precision,
+           sharding_specs=sharding_specs,
        )
     )
 
@@ -1059,6 +1066,7 @@ def __init__(
              use_additional_conditions=False,
              dtype=dtype,
              weights_dtype=weights_dtype,
+             sharding_specs=sharding_specs,
          )
       )
     else:
@@ -1155,6 +1163,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     self.encoder = LTX2VideoEncoder3d(
         in_channels=in_channels,
@@ -1196,6 +1205,7 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
+        sharding_specs=sharding_specs,
     )
 
     self.scaling_factor = scaling_factor
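Every hunk in this file follows the same plumbing pattern: a container class accepts an optional `sharding_specs` and forwards it unchanged, so only the leaf embedding layers ever interpret it. The pattern in miniature (class names illustrative, not from the diff):

from typing import Any, Optional

class LeafEmbedding:
  def __init__(self, sharding_specs: Optional[Any] = None):
    # Leaves read the specs with a backward-compatible default, as the
    # getattr calls in embeddings_flax.py above do.
    self.kernel_axes = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))

class Container:
  def __init__(self, sharding_specs: Optional[Any] = None):
    # Containers never inspect the specs; they only pass them down.
    self.embedder = LeafEmbedding(sharding_specs=sharding_specs)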
