@@ -60,12 +60,27 @@ class TextEncoderShardingSpecs:
6060 text_encoder_kernel : Optional [tuple ] = None
6161
6262
63+ @dataclass
64+ class TextConnectorShardingSpecs :
65+ """Specs for the Text Connector execution."""
66+
67+ net_0_kernel : tuple = (None , "mlp" )
68+ net_0_bias : tuple = ("mlp" ,)
69+ net_2_kernel : tuple = ("mlp" , None )
70+ net_2_bias : tuple = (None ,)
71+
72+
6373@dataclass
6474class VAEShardingSpecs :
6575 """Sharding specs for the VAE."""
6676
6777 vae_conv_kernel : Optional [tuple ] = None
6878 force_replication : bool = True
79+ # Specs for NNXTimestepEmbedding used in VAE
80+ emb_linear_1_kernel : tuple = ("embed" , "mlp" )
81+ emb_linear_1_bias : tuple = ("mlp" ,)
82+ emb_linear_2_kernel : tuple = ("mlp" , "embed" )
83+ emb_linear_2_bias : tuple = ("embed" ,)
6984
7085
7186# --- Unified Registry for LTX2 ---
@@ -80,6 +95,7 @@ class VAEShardingSpecs:
8095 use_batched_text_encoder = True ,
8196 text_encoder_kernel = (None , "embed" ),
8297 ),
98+ "text_connector" : TextConnectorShardingSpecs (),
8399 "vae" : VAEShardingSpecs (vae_conv_kernel = ("batch" , None , None , None )),
84100 },
85101 "trillium" : {
@@ -92,6 +108,7 @@ class VAEShardingSpecs:
92108 use_batched_text_encoder = False ,
93109 text_encoder_kernel = (None , "embed" ),
94110 ),
111+ "text_connector" : TextConnectorShardingSpecs (),
95112 "vae" : VAEShardingSpecs (vae_conv_kernel = (None , None , None , None )),
96113 },
97114}
0 commit comments