Skip to content

Commit 522a478

Browse files
committed
feat(ltx2): add section comments to all sharding spec classes
1 parent e2f4092 commit 522a478

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ class LTX2DiTShardingSpecs:
5656
class TextEncoderShardingSpecs:
5757
"""Specs for the Text Encoder execution."""
5858

59-
use_batched_text_encoder: bool = False
60-
text_encoder_kernel: Optional[tuple] = None
59+
use_batched_text_encoder: bool = False # Flag to batch text encoder execution
60+
text_encoder_kernel: Optional[tuple] = None # Spec for text encoder projection layers
6161

6262

6363
@dataclass
6464
class TextConnectorShardingSpecs:
6565
"""Specs for the Text Connector execution."""
6666

67+
# --- MLP Specs (NNXSimpleFeedForward) ---
6768
net_0_kernel: tuple = (None, "mlp")
6869
net_0_bias: tuple = ("mlp",)
6970
net_2_kernel: tuple = ("mlp", None)
7071
net_2_bias: tuple = (None,)
71-
# Attention specs (defaulting to trillium/safe defaults)
72+
73+
# --- Attention Specs (LTX2Attention) ---
7274
qkv_kernel: tuple = ("embed", "heads")
7375
out_kernel: tuple = ("heads", "embed")
7476
out_bias: tuple = ("embed",)
@@ -80,9 +82,11 @@ class TextConnectorShardingSpecs:
8082
class VAEShardingSpecs:
8183
"""Sharding specs for the VAE."""
8284

85+
# --- VAE Specific Specs ---
8386
vae_conv_kernel: Optional[tuple] = None
8487
force_replication: bool = True
85-
# Specs for NNXTimestepEmbedding used in VAE
88+
89+
# --- Shared Embeddings Specs (NNXTimestepEmbedding) ---
8690
emb_linear_1_kernel: tuple = ("embed", "mlp")
8791
emb_linear_1_bias: tuple = ("mlp",)
8892
emb_linear_2_kernel: tuple = ("mlp", "embed")

0 commit comments

Comments
 (0)