
Commit d4bc454

Charles Li authored and committed
Dump input/activation sharding info to json files
Uses inspect to get the call stack trace. Command to generate the input_shardings.json files: python -m tests.utils.run_sharding_dump
1 parent 5a4a9c3 commit d4bc454

27 files changed: 2096 additions & 23 deletions
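For orientation, the inspect-based lookup that the commit message refers to works by walking call frames and matching the sharded array against the caller's local variables by object identity; wrapper methods that sit between the model code and the sharding utility pass extra_stack_level=1 so the walk skips their extra frame. A minimal standalone sketch of the idea (illustrative names only; the committed helper is _get_sharding_desc in src/maxtext/utils/sharding.py, shown later in this diff):

import inspect
from pathlib import Path


def describe_value(value, extra_stack_level=0):
  """Label value as <caller filename>/<caller variable name>, if one can be found."""
  frame = inspect.currentframe()
  # Step past this helper plus any wrapper frames the caller asked us to skip.
  for _ in range(1 + extra_stack_level):
    frame = frame.f_back if frame is not None else None
  if frame is not None:
    # Match by object identity against that frame's local variables.
    names = [name for name, val in frame.f_locals.items() if val is value]
    if names:
      return f"{Path(frame.f_code.co_filename).stem}/{names[0]}"
  return "Unknown"


def shard_hint(value):
  # A pass-through wrapper (like the _maybe_shard_with_logical methods below)
  # adds one frame, so it forwards extra_stack_level=1 to skip itself.
  return describe_value(value, extra_stack_level=1)


def forward():
  hidden = [1.0, 2.0, 3.0]
  print(shard_hint(hidden))  # e.g. "<this_module>/hidden"


forward()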


src/maxtext/layers/moe.py

Lines changed: 1 addition & 0 deletions
@@ -459,6 +459,7 @@ def _maybe_shard_with_logical(self, inputs, logical_name):
         mesh=self.mesh,
         shard_mode=self.config.shard_mode,
         debug_sharding=self.config.debug_sharding,
+        extra_stack_level=1,
     )

   def _logical_to_mesh_axes(self, logical_name):

src/maxtext/layers/pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes):
         mesh=self.mesh,
         rules=self.config.logical_axis_rules,
         debug_sharding=self.config.debug_sharding,
+        extra_stack_level=1,
     )

   def _maybe_shard_with_name(self, inputs, sharding_name):

src/maxtext/models/deepseek.py

Lines changed: 9 additions & 6 deletions
@@ -184,16 +184,20 @@ def with_logical_constraint(self, x):
         mesh=self.mesh,
         shard_mode=self.config.shard_mode,
         debug_sharding=self.config.debug_sharding,
+        extra_stack_level=1,
     )

   def dropout_op(self, x, deterministic):
-    return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))
+    dropout = self.dropout(x, deterministic=deterministic)
+    return self.with_logical_constraint(dropout)

   def pre_attention_norm_op(self, x):
-    return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
+    pre_attention_norm = self.pre_self_attention_layer_norm(x)
+    return self.with_logical_constraint(pre_attention_norm)

   def post_attention_norm_op(self, x):
-    return self.with_logical_constraint(self.post_self_attention_layer_norm(x))
+    post_attention_norm = self.post_self_attention_layer_norm(x)
+    return self.with_logical_constraint(post_attention_norm)

   def attention_op(
       self,
@@ -332,9 +336,8 @@ def __init__(
     )

   def mlp_op(self, x, deterministic):
-    return self.with_logical_constraint(
-        self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
-    )
+    mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
+    return self.with_logical_constraint(mlp)

   def __call__(
       self,
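The refactor in this file, which binds the dropout, norm, and MLP outputs to named locals before applying with_logical_constraint, lines up with the inspect-based lookup sketched above (and implemented by _get_sharding_desc in src/maxtext/utils/sharding.py below): the lookup labels an activation by finding a caller local variable that is the same object, so an expression passed inline has nothing to match. A tiny standalone illustration (hypothetical names, not MaxText code):

import inspect


def name_of_argument(value):
  """Report the caller's local variable name bound to value, if any."""
  caller = inspect.currentframe().f_back  # the frame that called us (CPython)
  names = [name for name, val in caller.f_locals.items() if val is value]
  return names[0] if names else "Unknown"


def layer(x):
  print(name_of_argument(list(x)))  # inline temporary: no local name -> "Unknown"
  y = list(x)
  print(name_of_argument(y))        # named intermediate -> "y"


layer([1, 2, 3])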

src/maxtext/models/llama2.py

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ def __init__(
         mesh=self.mesh,
         shard_mode=config.shard_mode,
         debug_sharding=config.debug_sharding,
+        extra_stack_level=1,
     )

   def __call__(

src/maxtext/utils/sharding.py

Lines changed: 55 additions & 12 deletions
@@ -29,9 +29,11 @@
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils

+import inspect  # for debugging only
+from pathlib import Path

 _LOGGED_ACTIVATION_SHARDINGS = set()
-_LOGGED_LOGICAL_AXES = set()
+_ACTIVATION_SHARDINGS_DUMP = []


 def get_input_data_sharding(config, mesh):
@@ -45,51 +47,92 @@ def get_input_data_sharding(config, mesh):
   return data_sharding


-def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):
+def _get_sharding_desc(inputs, extra_stack_level):
+  """Get the inputs sharding description using inspect module"""
+  frame = inspect.currentframe()
+  # Traverse back extra_stack_level times:
+  for _ in range(1 + extra_stack_level):
+    if frame is not None:
+      frame = frame.f_back
+  if frame is not None:
+    callers_local_vars = frame.f_locals.items()
+
+    x = [var_name for var_name, var_val in callers_local_vars if var_val is inputs]
+    if len(x) > 0:
+      caller_path_full = inspect.stack()[1 + extra_stack_level].filename
+      # Use pathlib.Path to easily extract just the filename from the full path.
+      caller_filename = Path(caller_path_full).name
+      return f"{caller_filename[:-3]}/{x[0]}"
+  return "Unknown"
+
+
+def maybe_shard_with_name(
+    inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, sharding_desc="", logical_axes=None
+):
   """
   In auto shardmode, this function hints inputs follow given named_sharding.
   In explicit shardmode, this function enforces inputs following named_sharding.
+  sharding_desc is description of inputs of upper layer(s) of caller (with the form of <filename>/<variable>).
+  It is used as key in log/dump files when debug_sharding==true
   """
   if inputs is None:
     return None
   if (
       debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding)
   ):  # only print pspec for JitTracer
+    if not sharding_desc:
+      sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1)
+
+    if not logical_axes:
+      logical_axes = "Unknown"
+    elif isinstance(logical_axes, list):
+      logical_axes = tuple(logical_axes)
+
     pspec = remove_size_one_mesh_axis(getattr(named_sharding, "spec"), getattr(named_sharding, "mesh"))
-    log_key = (str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
+    log_key = (sharding_desc, str(jax.typeof(inputs)), tuple(pspec), extra_stack_level)
     if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
-      max_logging.info(f"Physical: {log_key[0]:.<80} {log_key[1]}.", stacklevel=3 + extra_stack_level)
+      max_logging.info(f"{sharding_desc} Logical: {log_key[1]:.<60} {logical_axes}.", stacklevel=3 + extra_stack_level)
+      max_logging.info(f"{sharding_desc} Physical: {log_key[1]:.<60} {log_key[2]}.", stacklevel=3 + extra_stack_level)
       _LOGGED_ACTIVATION_SHARDINGS.add(log_key)
+
+      _ACTIVATION_SHARDINGS_DUMP.append(
+          {
+              f"{sharding_desc}: {log_key[1]}": {
+                  "logic_axes": f"{logical_axes}",
+                  "PartitionSpec": f"P{log_key[2]}",
+              }
+          }
+      )
   if shard_mode == ShardMode.EXPLICIT:
     return reshard(inputs, named_sharding)
   else:
     return jax.lax.with_sharding_constraint(inputs, named_sharding)


 def maybe_shard_with_logical(
-    inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0
+    inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
 ):
   """
   A wrapper of maybe_shard_with_name when logical axes are inputs
+  sharding_desc is description of inputs of upper layer(s) of caller (with the form of <filename>/<variable>).
+  It is used as key in log/dump files when debug_sharding==true
   """
   if inputs is None:
     return None

-  named_sharding = create_sharding(mesh, logical_axes, rules=rules)
-
-  if debug_sharding and isinstance(inputs, Tracer):
-    log_key = (str(jax.typeof(inputs)), tuple(logical_axes), extra_stack_level)
+  if debug_sharding and not sharding_desc:
+    sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1)

-    if log_key not in _LOGGED_LOGICAL_AXES:
-      max_logging.info(f"Logical: {log_key[0]:.<60} {log_key[1]}", stacklevel=3 + extra_stack_level)
-      _LOGGED_LOGICAL_AXES.add(log_key)
+  named_sharding = create_sharding(mesh, logical_axes, rules=rules)

   return maybe_shard_with_name(
       inputs,
       named_sharding,
       shard_mode,
       debug_sharding=debug_sharding,
       extra_stack_level=extra_stack_level + 1,
+      sharding_desc=sharding_desc,
+      logical_axes=logical_axes,
   )
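Each newly seen activation entry is also appended to the module-level _ACTIVATION_SHARDINGS_DUMP list; per the commit message, python -m tests.utils.run_sharding_dump writes these entries out as input_shardings.json (that script is among the 27 changed files but not shown in this excerpt). A hedged sketch of what such a dump step could look like, with a hypothetical helper name:

# Hypothetical serialization sketch; the committed tests/utils/run_sharding_dump
# script may differ in structure and output location.
import json

from maxtext.utils import sharding


def dump_activation_shardings(path="input_shardings.json"):
  # After tracing a model step with debug_sharding=True, each entry looks like
  #   {"<file>/<variable>: <abstract value>": {"logic_axes": "...", "PartitionSpec": "P(...)"}}
  with open(path, "w", encoding="utf-8") as f:
    json.dump(sharding._ACTIVATION_SHARDINGS_DUMP, f, indent=2)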

src/maxtext/utils/vocabulary_tiling.py

Lines changed: 4 additions & 1 deletion
@@ -89,7 +89,10 @@ def vocab_tiling_linen_loss(
   )

   _maybe_shard_with_name = functools.partial(
-      maybe_shard_with_name, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding
+      maybe_shard_with_name,
+      shard_mode=config.shard_mode,
+      debug_sharding=config.debug_sharding,
+      extra_stack_level=1,
   )

   def _reshape(inputs, out_shape, out_sharding):

0 commit comments
