Skip to content

Commit 10ae8f3

Browse files
committed
fix in ltx2_pipeline
1 parent 5475850 commit 10ae8f3

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def cast_with_exclusion(path, x, dtype_to_cast):
9292
return x.astype(dtype_to_cast)
9393

9494

95-
def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:
95+
def _add_sharding_rule(vs: nnx.Variable, logical_axis_rules) -> nnx.Variable:
9696
vs.sharding_rules = logical_axis_rules
9797
return vs
9898

0 commit comments

Comments
 (0)