Skip to content

Commit e0f262d

Browse files
committed
Update
1 parent 702bd6d commit e0f262d

2 files changed

Lines changed: 21 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,27 @@ class TextEncoderShardingSpecs:
6060
text_encoder_kernel: Optional[tuple] = None
6161

6262

63+
@dataclass
64+
class TextConnectorShardingSpecs:
65+
"""Specs for the Text Connector execution."""
66+
67+
net_0_kernel: tuple = (None, "mlp")
68+
net_0_bias: tuple = ("mlp",)
69+
net_2_kernel: tuple = ("mlp", None)
70+
net_2_bias: tuple = (None,)
71+
72+
6373
@dataclass
6474
class VAEShardingSpecs:
6575
"""Sharding specs for the VAE."""
6676

6777
vae_conv_kernel: Optional[tuple] = None
6878
force_replication: bool = True
79+
# Specs for NNXTimestepEmbedding used in VAE
80+
emb_linear_1_kernel: tuple = ("embed", "mlp")
81+
emb_linear_1_bias: tuple = ("mlp",)
82+
emb_linear_2_kernel: tuple = ("mlp", "embed")
83+
emb_linear_2_bias: tuple = ("embed",)
6984

7085

7186
# --- Unified Registry for LTX2 ---
@@ -80,6 +95,7 @@ class VAEShardingSpecs:
8095
use_batched_text_encoder=True,
8196
text_encoder_kernel=(None, "embed"),
8297
),
98+
"text_connector": TextConnectorShardingSpecs(),
8399
"vae": VAEShardingSpecs(vae_conv_kernel=("batch", None, None, None)),
84100
},
85101
"trillium": {
@@ -92,6 +108,7 @@ class VAEShardingSpecs:
92108
use_batched_text_encoder=False,
93109
text_encoder_kernel=(None, "embed"),
94110
),
111+
"text_connector": TextConnectorShardingSpecs(),
95112
"vae": VAEShardingSpecs(vae_conv_kernel=(None, None, None, None)),
96113
},
97114
}

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Tuple, Union, List
17+
from typing import Tuple, Union, List, Any
1818
import jax
1919
import jax.numpy as jnp
2020
from flax import nnx
@@ -57,6 +57,7 @@ def __init__(
5757
attention_kernel: str = "flash",
5858
mesh: jax.sharding.Mesh = None,
5959
rngs: nnx.Rngs = None,
60+
sharding_specs: Optional[Any] = None,
6061
**kwargs,
6162
):
6263
input_dim = caption_channels * text_proj_in_factor
@@ -82,6 +83,7 @@ def __init__(
8283
attention_kernel=attention_kernel,
8384
mesh=mesh,
8485
rngs=rngs,
86+
sharding_specs=sharding_specs,
8587
)
8688

8789
self.audio_embeddings_connector = Embeddings1DConnector(
@@ -97,6 +99,7 @@ def __init__(
9799
attention_kernel=attention_kernel,
98100
mesh=mesh,
99101
rngs=rngs,
102+
sharding_specs=sharding_specs,
100103
)
101104

102105
def __call__(

0 commit comments

Comments
 (0)