@@ -54,20 +54,20 @@ def calibrate_fbs(config):
5454
5555 print (f"Creating model with flash_block_sizes: { ltx2_config_dict ['flash_block_sizes' ]} " )
5656
57+ print (f"Loading Sharded Transformer using LTX2Pipeline.load_transformer..." )
58+ from maxdiffusion .pipelines .ltx2 .ltx2_pipeline import LTX2Pipeline
59+
5760 with mesh :
58- # Standard initialization
59- transformer = LTX2VideoTransformer3DModel (** ltx2_config_dict , rngs = rngs )
60-
61- # Shard the model
62- graphdef , state , rest_of_state = nnx .split (transformer , nnx .Param , ...)
63- def _add_sharding_rule (vs : nnx .VariableState , logical_axis_rules ):
64- vs .sharding_rules = logical_axis_rules
65- return vs
66-
67- p_add_sharding_rule = partial (_add_sharding_rule , logical_axis_rules = config .logical_axis_rules )
68- state_sharded = jax .tree .map (p_add_sharding_rule , state , is_leaf = lambda x : isinstance (x , nnx .VariableState ))
69- pspecs = nnx .get_partition_spec (state_sharded )
70- sharded_state = jax .lax .with_sharding_constraint (state_sharded , pspecs )
61+ # Load the transformer via the Hugging Face sharded-loading mechanism to bypass the 16 GB single-device allocation limit
62+ transformer = LTX2Pipeline .load_transformer (
63+ devices_array = devices_array ,
64+ mesh = mesh ,
65+ rngs = rngs ,
66+ config = config ,
67+ restored_checkpoint = None ,
68+ subfolder = "transformer" ,
69+ )
70+ graphdef , sharded_state , rest_of_state = nnx .split (transformer , nnx .Param , ...)
7171
7272 from maxdiffusion .pipelines .ltx2 .ltx2_pipeline import transformer_forward_pass
7373