@@ -56,19 +56,21 @@ class LTX2DiTShardingSpecs:
5656class TextEncoderShardingSpecs :
5757 """Specs for the Text Encoder execution."""
5858
59- use_batched_text_encoder : bool = False
60- text_encoder_kernel : Optional [tuple ] = None
59+ use_batched_text_encoder : bool = False # Flag to batch text encoder execution
60+ text_encoder_kernel : Optional [tuple ] = None # Spec for text encoder projection layers
6161
6262
6363@dataclass
6464class TextConnectorShardingSpecs :
6565 """Specs for the Text Connector execution."""
6666
67+ # --- MLP Specs (NNXSimpleFeedForward) ---
6768 net_0_kernel : tuple = (None , "mlp" )
6869 net_0_bias : tuple = ("mlp" ,)
6970 net_2_kernel : tuple = ("mlp" , None )
7071 net_2_bias : tuple = (None ,)
71- # Attention specs (defaulting to trillium/safe defaults)
72+
73+ # --- Attention Specs (LTX2Attention) ---
7274 qkv_kernel : tuple = ("embed" , "heads" )
7375 out_kernel : tuple = ("heads" , "embed" )
7476 out_bias : tuple = ("embed" ,)
@@ -80,9 +82,11 @@ class TextConnectorShardingSpecs:
8082class VAEShardingSpecs :
8183 """Sharding specs for the VAE."""
8284
85+ # --- VAE Specific Specs ---
8386 vae_conv_kernel : Optional [tuple ] = None
8487 force_replication : bool = True
85- # Specs for NNXTimestepEmbedding used in VAE
88+
89+ # --- Shared Embeddings Specs (NNXTimestepEmbedding) ---
8690 emb_linear_1_kernel : tuple = ("embed" , "mlp" )
8791 emb_linear_1_bias : tuple = ("mlp" ,)
8892 emb_linear_2_kernel : tuple = ("mlp" , "embed" )
0 commit comments