@@ -944,6 +944,15 @@ def setup_initial_state(
944944 return state , state_mesh_annotations , state_mesh_shardings , data_iterator
945945
946946
def get_logical_annotations(model, tx, config, rng, mesh, is_training=True):
  """Get the partition specs of the state (including optimizer), computed from shapes only.

  The initial state is never materialized: `init_initial_state` is traced with
  `jax.eval_shape` under the given mesh and `config.logical_axis_rules`, and the
  resulting abstract state is handed to `nn.get_partition_spec`.

  Returns:
    Whatever `nn.get_partition_spec` produces for the abstract state
    (presumably a pytree of logical axis annotations -- confirm against Flax docs).
  """
  shape_only_init = functools.partial(init_initial_state, model, tx, config, is_training, rng)

  # Both context managers must be active while the init function is traced.
  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
    state_shapes = jax.eval_shape(shape_only_init)

  return nn.get_partition_spec(state_shapes)
954+
955+
947956def get_abstract_state (model , tx , config , rng , mesh , is_training = True ):
948957 """Get a shaped abstraction of the state (including optimizer)"""
949958 init_state_partial = functools .partial (init_initial_state , model , tx , config , is_training , rng )
@@ -1227,15 +1236,32 @@ def schedule(step):
12271236 return optax .join_schedules (pieces , boundaries )
12281237
12291238
def print_shardings_params(params, params_sharding, mesh, logical_annotations=None):
  """Log, per parameter, its shape, logical annotation and physical sharding.

  Args:
    params: Pytree of parameters, or an object exposing a `params` attribute.
      Anything without that attribute is wrapped as `{"params": params}` so the
      printed paths have a common root.
    params_sharding: Pytree paired leaf-by-leaf with `params`; every leaf must
      expose a `.spec` attribute. Wrapped the same way as `params`.
    mesh: Mesh forwarded to `sharding.remove_size_one_mesh_axis`.
    logical_annotations: Optional pytree paired leaf-by-leaf with `params`.
      When None, the "Logical" line is omitted and the physical sharding is
      still reported for every parameter.
  """
  if not hasattr(params, "params"):
    params = {"params": params}
  if not hasattr(params_sharding, "params"):
    params_sharding = {"params": params_sharding}

  leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
  leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)

  show_logical = logical_annotations is not None
  if show_logical:
    if not hasattr(logical_annotations, "params"):
      logical_annotations = {"params": logical_annotations}
    leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations)
    logical_values = [leaf for _, leaf in leaves_logical]
  else:
    # Flattening None yields no leaves, which would make the zip below empty
    # and silently suppress all output; pair every parameter with a placeholder.
    logical_values = [None] * len(leaves_params)

  for (path, leaf_val), (_, leaf_sharding), leaf_logical_val in zip(leaves_params, leaves_sharding, logical_values):
    path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
    shape = jax.typeof(leaf_val)
    pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
    pspec_str = str(tuple(pspec))

    pieces = [f" {path_str} \n ", f" Shape: {shape} \n "]
    if show_logical:
      pieces.append(f" Logical: {leaf_logical_val} \n ")
    pieces.append(f" Physical: {pspec_str} ")
    max_logging.info("".join(pieces))

  # Emit a trailing blank line and flush stdout once the listing is complete.
  print(flush=True)
12391265
12401266
12411267def maybe_dump_jaxpr (config , p_train_step , train_step_inputs ):
0 commit comments