Skip to content

Commit 0216b16

Browse files
committed
feat(ltx2): add defaults to LTX2DiTShardingSpecs matching Trillium
1 parent 522a478 commit 0216b16

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class LTX2DiTShardingSpecs:
2525
"""Sharding specs for the LTX2 Diffusion Transformer."""
2626

2727
# --- Attention Layers (LTX2Attention) ---
28-
qkv_kernel: tuple
29-
out_kernel: tuple
30-
out_bias: tuple
28+
qkv_kernel: tuple = ("embed", "heads")
29+
out_kernel: tuple = ("heads", "embed")
30+
out_bias: tuple = ("embed",)
3131
qkv_bias: tuple = ("heads",)
3232

3333
# --- Feed Forward Network (NNXSimpleFeedForward) ---
@@ -113,11 +113,7 @@ class VAEShardingSpecs:
113113
"vae": VAEShardingSpecs(),
114114
},
115115
"trillium": {
116-
"ltx2_dit": LTX2DiTShardingSpecs(
117-
qkv_kernel=("embed", "heads"),
118-
out_kernel=("heads", "embed"),
119-
out_bias=("embed",),
120-
),
116+
"ltx2_dit": LTX2DiTShardingSpecs(),
121117
"text_encoder": TextEncoderShardingSpecs(
122118
use_batched_text_encoder=False,
123119
text_encoder_kernel=(None, "embed"),

0 commit comments

Comments
 (0)