@@ -329,13 +329,18 @@ def load_connectors(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, co
329329 max_logging .log ("Loading Connectors..." )
330330
331331 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
332+ sharding_config = getattr (config , "sharding" , {})
333+ connector_strategy = sharding_config .get ("text_connector" , "default" )
334+ connector_specs = get_sharding_specs (connector_strategy , "text_connector" )
335+
332336 connectors = LTX2AudioVideoGemmaTextEncoder .from_config (
333337 config .pretrained_model_name_or_path ,
334338 subfolder = "connectors" ,
335339 rngs = rngs ,
336340 mesh = mesh ,
337341 dtype = jnp .float32 ,
338342 weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
343+ sharding_specs = connector_specs ,
339344 )
340345 return connectors
341346
@@ -371,13 +376,18 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
371376 max_logging .log ("Loading Video VAE..." )
372377
373378 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
379+ sharding_config = getattr (config , "sharding" , {})
380+ vae_strategy = sharding_config .get ("vae" , "default" )
381+ vae_specs = get_sharding_specs (vae_strategy , "vae" )
382+
374383 vae = LTX2VideoAutoencoderKL .from_config (
375384 config .pretrained_model_name_or_path ,
376385 subfolder = "vae" ,
377386 rngs = rngs ,
378387 mesh = mesh ,
379388 dtype = jnp .float32 ,
380389 weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
390+ sharding_specs = vae_specs ,
381391 )
382392 return vae
383393
0 commit comments