Skip to content

Commit fcb87c0

Browse files
Merge pull request #2985 from CIeNET-International:user/sharony/logicaxes
PiperOrigin-RevId: 862298007
2 parents 339dc7e + 87a6b93 commit fcb87c0

7 files changed

Lines changed: 82 additions & 14 deletions

File tree

src/MaxText/maxtext_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,15 @@ def setup_initial_state(
944944
return state, state_mesh_annotations, state_mesh_shardings, data_iterator
945945

946946

947+
def get_logical_annotations(model, tx, config, rng, mesh, is_training=True):
948+
init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng)
949+
950+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
951+
abstract_state = jax.eval_shape(init_state_partial)
952+
logical_annotations = nn.get_partition_spec(abstract_state)
953+
return logical_annotations
954+
955+
947956
def get_abstract_state(model, tx, config, rng, mesh, is_training=True):
948957
"""Get a shaped abstraction of the state (including optimizer)"""
949958
init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng)
@@ -1227,15 +1236,32 @@ def schedule(step):
12271236
return optax.join_schedules(pieces, boundaries)
12281237

12291238

1230-
def print_shardings_params(params, params_sharding, mesh):
1231-
"""Print state shardings."""
1239+
def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
1240+
"""
1241+
Print state shardings comparing Logical Definition vs Physical Result.
1242+
"""
1243+
if not hasattr(params, "params"):
1244+
params = {"params": params}
1245+
if not hasattr(params_sharding, "params"):
1246+
params_sharding = {"params": params_sharding}
1247+
if logical_annotations and not hasattr(logical_annotations, "params"):
1248+
logical_annotations = {"params": logical_annotations}
1249+
12321250
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
12331251
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
1234-
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1252+
leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
1253+
1254+
for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical):
12351255
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
12361256
shape = jax.typeof(leaf_val)
12371257
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
1238-
max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
1258+
pspec_str = str(tuple(pspec))
1259+
logical_str = str(leaf_logical_val)
1260+
1261+
message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}"
1262+
max_logging.info(message)
1263+
1264+
print(flush=True)
12391265

12401266

12411267
def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):

src/MaxText/model_creation_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,12 @@ def create_sharded_state():
157157
# print weights sharding info under debug sharding mode
158158
if config.debug_sharding:
159159
max_utils.print_non_trivial_mesh_axis(model.mesh)
160-
maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh)
160+
maxtext_utils.print_shardings_params(
161+
params=sharded_state,
162+
params_sharding=out_shardings,
163+
mesh=model.mesh,
164+
logical_annotations=specs,
165+
)
161166
if config.load_parameters_path:
162167
try:
163168
ckptr = ocp.Checkpointer(

src/MaxText/sharding.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232

3333
_LOGGED_ACTIVATION_SHARDINGS = set()
34+
_LOGGED_LOGICAL_AXES = set()
3435

3536

3637
def get_input_data_sharding(config, mesh):
@@ -51,7 +52,7 @@ def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=Fal
5152
pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
5253
log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
5354
if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
54-
max_logging.info(f"{log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
55+
max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
5556
_LOGGED_ACTIVATION_SHARDINGS.add(log_key)
5657
if shard_mode == ShardMode.EXPLICIT:
5758
return reshard(inputs, named_sharding)
@@ -67,9 +68,22 @@ def maybe_shard_with_logical(
6768
"""
6869
if inputs is None:
6970
return None
71+
7072
named_sharding = create_sharding(mesh, logical_axes, rules=rules)
73+
74+
if debug_sharding and isinstance(inputs, Tracer):
75+
log_key = (str(jax.typeof(inputs)), logical_axes, extra_stack_level)
76+
77+
if log_key not in _LOGGED_LOGICAL_AXES:
78+
max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
79+
_LOGGED_LOGICAL_AXES.add(log_key)
80+
7181
return maybe_shard_with_name(
72-
inputs, named_sharding, shard_mode, debug_sharding=debug_sharding, extra_stack_level=extra_stack_level + 1
82+
inputs,
83+
named_sharding,
84+
shard_mode,
85+
debug_sharding=debug_sharding,
86+
extra_stack_level=extra_stack_level + 1,
7387
)
7488

7589

src/MaxText/train_compile.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,15 @@ def get_shaped_inputs(topology_mesh, config):
104104
model, tx, config, example_rng, topology_mesh
105105
)
106106

107+
# unsharded logical annotations
108+
logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
109+
107110
# Shaped batch
108111
shaped_batch = maxtext_utils.get_shaped_batch(config)
109112

110113
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
111114
shaped_train_kwargs = {}
112-
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model
115+
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model
113116

114117

115118
def jit_and_compile(
@@ -160,7 +163,13 @@ def is_oom(argv: Sequence[str]) -> bool:
160163
max_utils.print_system_information()
161164

162165
# Get shaped inputs
163-
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
166+
(
167+
shaped_train_args,
168+
shaped_train_kwargs,
169+
state_mesh_shardings,
170+
_,
171+
model,
172+
) = get_shaped_inputs(topology_mesh, config)
164173

165174
# Get data sharding
166175
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -216,7 +225,13 @@ def main(argv: Sequence[str]) -> None:
216225
max_utils.print_system_information()
217226

218227
# Get shaped inputs
219-
shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config)
228+
(
229+
shaped_train_args,
230+
shaped_train_kwargs,
231+
state_mesh_shardings,
232+
logical_annotations,
233+
model,
234+
) = get_shaped_inputs(topology_mesh, config)
220235

221236
# Get data sharding
222237
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
@@ -231,7 +246,12 @@ def main(argv: Sequence[str]) -> None:
231246
# print weights sharding info under debug sharding mode
232247
if config.debug_sharding:
233248
max_utils.print_non_trivial_mesh_axis(topology_mesh)
234-
maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh)
249+
maxtext_utils.print_shardings_params(
250+
shaped_train_args[0].params,
251+
state_mesh_shardings.params,
252+
topology_mesh,
253+
logical_annotations.params,
254+
)
235255

236256
# Compile
237257
print("Jitting and compiling train step...", flush=True)

src/MaxText/train_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,11 @@ def setup_train_loop(config, recorder, devices=None):
217217

218218
# print weights sharding info under debug sharding mode
219219
if config.debug_sharding:
220+
logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True)
220221
max_utils.print_non_trivial_mesh_axis(model.mesh)
221-
maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh)
222+
maxtext_utils.print_shardings_params(
223+
state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params
224+
)
222225

223226
if config.use_dpo:
224227
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)

tests/unit/sharding_compare_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
9797
validate_config(config)
9898

9999
topology_mesh = get_topology_mesh(config)
100-
_, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
100+
_, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
101101
actual_json = named_shardings_to_json(state_mesh_shardings)
102102
expected_json = load_named_sharding_json(json_path)
103103

tests/utils/sharding_dump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def main(argv: Sequence[str]) -> None:
276276

277277
try:
278278
topology_mesh = get_topology_mesh(config)
279-
_, _, state_mesh_shardings, _ = get_shaped_inputs(topology_mesh, config)
279+
_, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
280280
except: # pylint: disable=bare-except
281281
state_mesh_shardings = {}
282282

0 commit comments

Comments
 (0)