Commit 001651b

ltx2 yaml change
1 parent 3408630 commit 001651b

1 file changed: 13 additions & 13 deletions

src/maxdiffusion/scripts/calibrate_ltx2_fbs.py
@@ -54,20 +54,20 @@ def calibrate_fbs(config):
 
   print(f"Creating model with flash_block_sizes: {ltx2_config_dict['flash_block_sizes']}")
 
+  print(f"Loading Sharded Transformer using LTX2Pipeline.load_transformer...")
+  from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
+
   with mesh:
-    # Standard initialization
-    transformer = LTX2VideoTransformer3DModel(**ltx2_config_dict, rngs=rngs)
-
-    # Shard the model
-    graphdef, state, rest_of_state = nnx.split(transformer, nnx.Param, ...)
-    def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules):
-      vs.sharding_rules = logical_axis_rules
-      return vs
-
-    p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules)
-    state_sharded = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
-    pspecs = nnx.get_partition_spec(state_sharded)
-    sharded_state = jax.lax.with_sharding_constraint(state_sharded, pspecs)
+    # Load transformer via the robust HF sharded logical mechanism to bypass 16GB Single-Device Allocation Limit
+    transformer = LTX2Pipeline.load_transformer(
+        devices_array=devices_array,
+        mesh=mesh,
+        rngs=rngs,
+        config=config,
+        restored_checkpoint=None,
+        subfolder="transformer",
+    )
+    graphdef, sharded_state, rest_of_state = nnx.split(transformer, nnx.Param, ...)
 
   from maxdiffusion.pipelines.ltx2.ltx2_pipeline import transformer_forward_pass
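For context, the deleted block is the standard Flax NNX logical-axis sharding recipe: initialize the model eagerly, split out the parameters, tag each nnx.VariableState with logical-to-mesh rules, derive PartitionSpecs, and constrain the state onto the mesh. The sketch below reproduces that recipe in self-contained, runnable form; ToyBlock, the axis names ("embed", "mlp"), and the rules tuple are illustrative stand-ins for LTX2VideoTransformer3DModel and config.logical_axis_rules, not values from this repo.

from functools import partial

import jax
from flax import nnx
from jax.sharding import Mesh

class ToyBlock(nnx.Module):
    # Hypothetical stand-in for LTX2VideoTransformer3DModel: a single weight
    # annotated with logical axis names via nnx.with_partitioning.
    def __init__(self, rngs: nnx.Rngs):
        init = nnx.initializers.lecun_normal()
        self.w = nnx.Param(
            nnx.with_partitioning(init, ("embed", "mlp"))(rngs.params(), (256, 1024))
        )

    def __call__(self, x):
        return x @ self.w

# Illustrative logical-to-mesh rules; the real script reads these from
# config.logical_axis_rules.
logical_axis_rules = (("embed", None), ("mlp", "data"))
mesh = Mesh(jax.devices(), axis_names=("data",))

with mesh:
    # Eager init materializes every array on the default device first; this is
    # the step that breaks once the checkpoint exceeds a single device's memory.
    model = ToyBlock(rngs=nnx.Rngs(0))

    # Split out the parameters, attach the rules to each VariableState, derive
    # PartitionSpecs, and constrain the state onto the mesh.
    graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...)

    def _add_sharding_rule(vs: nnx.VariableState, rules):
        vs.sharding_rules = rules
        return vs

    state = jax.tree.map(
        partial(_add_sharding_rule, rules=logical_axis_rules),
        state,
        is_leaf=lambda x: isinstance(x, nnx.VariableState),
    )
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)

The weakness this commit addresses sits in the first step: the full, unsharded parameter tree must exist on one device before the sharding constraint is applied, which is what the new load_transformer path sidesteps.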

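To illustrate why placing tensors sharded at load time avoids the per-device cap (the general idea only, not LTX2Pipeline.load_transformer's actual implementation): a host-side checkpoint tensor can be transferred with jax.device_put under a NamedSharding, so each device only ever receives its own slice.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))

# Host-side array standing in for one tensor read from a checkpoint shard.
num_devices = len(jax.devices())
host_tensor = np.ones((num_devices * 2, 1024), dtype=np.float32)

# device_put with a NamedSharding splits the tensor as it is transferred, so
# no single device ever holds the whole (potentially multi-GB) array.
sharded_tensor = jax.device_put(host_tensor, NamedSharding(mesh, P("data", None)))
print(sharded_tensor.sharding)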