Skip to content

Commit f130e8e

Browse files
Cristian GarciaGoogle-ML-Automation
authored andcommitted
Rename sharding_names to out_sharding in NNX Variable metadata
This CL renames the sharding_names attribute to out_sharding for better consistency with the sharding API. The new name more clearly indicates the purpose of this metadata field. ## Changes - Bump Flax version to 0.12.4 - Core changes in variablelib.py: - Add sharding_names to out_sharding metadata remapping for backward compatibility - Add deprecated sharding_names property that returns out_sharding with a warning - Update nnx/spmd.py, core/spmd.py, core/meta.py, linen/spmd.py to use out_sharding - Update all NNX tests to use the new attribute name - Update qwix flax_util.py to check for out_sharding first, with fallback to sharding_names - Update maxtext initializers.py to check for out_sharding first - Update documentation and examples to use out_sharding ## Backward Compatibility Existing code using sharding_names will continue to work via: - Metadata remapping during Variable creation - Deprecated Variable.sharding_names property PiperOrigin-RevId: 869269899
1 parent b6e0cdb commit f130e8e

1 file changed

Lines changed: 13 additions & 9 deletions

File tree

src/MaxText/layers/initializers.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
6464
"""Wraps an NNX variable's value in `nn.LogicallyPartitioned`.
6565
6666
This function inspects the metadata of an `nnx.VariableState` object. If
67-
sharding information ('sharding' or 'sharding_names') is present, it wraps
68-
the variable's value in `nn.LogicallyPartitioned` to apply the specified
69-
sharding constraints.
67+
sharding information ('out_sharding', 'sharding' or 'sharding_names') is
68+
present, it wraps the variable's value in `nn.LogicallyPartitioned` to apply
69+
the specified sharding constraints.
7070
7171
It handles special cases for `aqt_tensor.QTensor` and variables of type
7272
`_overwrite_with_gradient` by returning their values directly without
@@ -85,14 +85,18 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
8585
return variable.value
8686

8787
metadata = variable.get_metadata()
88-
if "sharding" in metadata or "sharding_names" in metadata:
89-
if "sharding_names" in metadata:
90-
sharding_names = metadata["sharding_names"]
91-
else:
92-
sharding_names = metadata["sharding"]
88+
out_sharding = None
89+
if "out_sharding" in metadata:
90+
out_sharding = metadata["out_sharding"]
91+
elif "sharding_names" in metadata:
92+
out_sharding = metadata["sharding_names"]
93+
elif "sharding" in metadata:
94+
out_sharding = metadata["sharding"]
95+
96+
if out_sharding is not None:
9397
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
9498
variable.value,
95-
sharding_names, # type: ignore[arg-type]
99+
out_sharding, # type: ignore[arg-type]
96100
mesh=metadata.get("mesh"),
97101
rules=metadata.get("rules"),
98102
)

0 commit comments

Comments
 (0)