2929from maxtext .utils import max_logging
3030from maxtext .utils import max_utils
3131
32+ import inspect # for debugging only
33+ from pathlib import Path
3234
3335_LOGGED_ACTIVATION_SHARDINGS = set ()
34- _LOGGED_LOGICAL_AXES = set ()
36+ _ACTIVATION_SHARDINGS_DUMP = []
3537
3638
3739def get_input_data_sharding (config , mesh ):
@@ -45,51 +47,92 @@ def get_input_data_sharding(config, mesh):
4547 return data_sharding
4648
4749
48- def maybe_shard_with_name (inputs , named_sharding , shard_mode , debug_sharding = False , extra_stack_level = 0 ):
50+ def _get_sharding_desc (inputs , extra_stack_level ):
51+ """Get the inputs sharding description using inspect module"""
52+ frame = inspect .currentframe ()
53+ # Traverse back extra_stack_level times:
54+ for _ in range (1 + extra_stack_level ):
55+ if frame is not None :
56+ frame = frame .f_back
57+ if frame is not None :
58+ callers_local_vars = frame .f_locals .items ()
59+
60+ x = [var_name for var_name , var_val in callers_local_vars if var_val is inputs ]
61+ if len (x ) > 0 :
62+ caller_path_full = inspect .stack ()[1 + extra_stack_level ].filename
63+ # Use pathlib.Path to easily extract just the filename from the full path.
64+ caller_filename = Path (caller_path_full ).name
65+ return f"{ caller_filename [:- 3 ]} /{ x [0 ]} "
66+ return "Unknown"
67+
68+
def maybe_shard_with_name(
    inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0, sharding_desc="", logical_axes=None
):
  """Apply `named_sharding` to `inputs`, optionally logging the sharding once.

  In auto shardmode, this function hints inputs follow given named_sharding.
  In explicit shardmode, this function enforces inputs following named_sharding.

  Args:
    inputs: Value to shard; `None` is passed through unchanged.
    named_sharding: Sharding to apply to `inputs`.
    shard_mode: `ShardMode.EXPLICIT` reshards; any other value applies a
      sharding constraint.
    debug_sharding: When true, log and record the sharding of traced inputs.
    extra_stack_level: Extra caller frames to skip, both for the reported log
      location and for the automatic variable-name lookup.
    sharding_desc: Description of the inputs in the form <filename>/<variable>,
      used as key in log/dump entries when debug_sharding is on. Looked up from
      the caller's frame when empty.
    logical_axes: Logical axis names shown next to the physical spec; reported
      as "Unknown" when empty or None.

  Returns:
    `None` if `inputs` is `None`, otherwise the result of `reshard` (explicit
    mode) or `jax.lax.with_sharding_constraint`.
  """
  if inputs is None:
    return None

  # Only traced values with a NamedSharding are logged.
  if debug_sharding and isinstance(inputs, Tracer) and isinstance(named_sharding, NamedSharding):
    if not sharding_desc:
      # Must be called directly from this frame: the helper counts frames upward.
      sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1)

    # Display form of the logical axes; lists are shown as tuples.
    if not logical_axes:
      axes_desc = "Unknown"
    elif isinstance(logical_axes, list):
      axes_desc = tuple(logical_axes)
    else:
      axes_desc = logical_axes

    pspec = tuple(remove_size_one_mesh_axis(named_sharding.spec, named_sharding.mesh))
    input_type = str(jax.typeof(inputs))
    log_key = (sharding_desc, input_type, pspec, extra_stack_level)
    if log_key not in _LOGGED_ACTIVATION_SHARDINGS:
      max_logging.info(f"{sharding_desc} Logical: {input_type:.<60} {axes_desc}.", stacklevel=3 + extra_stack_level)
      max_logging.info(f"{sharding_desc} Physical: {input_type:.<60} {pspec}.", stacklevel=3 + extra_stack_level)
      _LOGGED_ACTIVATION_SHARDINGS.add(log_key)

      # Record each distinct entry once, alongside its log lines.
      _ACTIVATION_SHARDINGS_DUMP.append(
          {
              f"{sharding_desc}: {input_type}": {
                  "logic_axes": f"{axes_desc}",
                  "PartitionSpec": f"P{pspec}",
              }
          }
      )

  if shard_mode == ShardMode.EXPLICIT:
    return reshard(inputs, named_sharding)
  return jax.lax.with_sharding_constraint(inputs, named_sharding)
67110
68111
def maybe_shard_with_logical(
    inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
):
  """A wrapper of maybe_shard_with_name when logical axes are inputs.

  Args:
    inputs: Value to shard; `None` is passed through unchanged.
    logical_axes: Logical axis names used to build the sharding via
      `create_sharding`.
    mesh: Mesh passed to `create_sharding`.
    shard_mode: Forwarded to `maybe_shard_with_name`.
    rules: Optional rules forwarded to `create_sharding`.
    debug_sharding: When true, the sharding of traced inputs is logged.
    extra_stack_level: Extra caller frames to skip for log location and for
      the automatic variable-name lookup.
    sharding_desc: Description of the inputs in the form <filename>/<variable>,
      used as key in log/dump entries when debug_sharding is on. Looked up from
      the caller's frame when empty.

  Returns:
    `None` if `inputs` is `None`, otherwise whatever `maybe_shard_with_name`
    returns.
  """
  if inputs is None:
    return None

  # The description is only consumed when a traced value is logged, so skip the
  # stack inspection for everything else.
  if debug_sharding and not sharding_desc and isinstance(inputs, Tracer):
    sharding_desc = _get_sharding_desc(inputs, extra_stack_level + 1)

  named_sharding = create_sharding(mesh, logical_axes, rules=rules)

  return maybe_shard_with_name(
      inputs,
      named_sharding,
      shard_mode,
      debug_sharding=debug_sharding,
      extra_stack_level=extra_stack_level + 1,
      sharding_desc=sharding_desc,
      logical_axes=logical_axes,
  )
94137
95138
0 commit comments