@@ -54,12 +54,61 @@ def find_fsdp(pspec):
5454
def generate_bsw_pps_from_pps(physical_partition_spec):
    """Create bsw physical partition spec from weight physical partition spec.

    Every leaf of ``physical_partition_spec`` is passed through
    ``remove_gathered_axes_from_physical_partition_spec`` and then has its
    leading entry dropped before being re-wrapped in a ``PartitionSpec``.

    Args:
      physical_partition_spec: A pytree whose leaves are physical partition
        specs for the weights.

    Returns:
      A pytree with the same structure whose leaves are ``PartitionSpec``
      objects with the gathered axes removed and the first entry dropped.
    """

    def _path_entry_name(entry):
        """Return the name carried by a single key-path entry."""
        # Dict-style entries expose `.key`; attribute-style entries expose
        # `.name`. Falling straight back to str() for the latter yields a
        # dotted form (".MoeBlock_0") that would never match the name below.
        if hasattr(entry, "key"):
            return entry.key
        if hasattr(entry, "name"):
            return entry.name
        return str(entry)

    def _process_pps(path, pps):
        """Strip gathered axes from one leaf, using its path to spot MoeBlock_0."""
        path_keys = [_path_entry_name(entry) for entry in path]
        is_moe_block_0 = "MoeBlock_0" in path_keys

        # 'expert' is only kept for leaves that live under MoeBlock_0.
        processed_pps = remove_gathered_axes_from_physical_partition_spec(
            pps, is_moe_block_0
        )

        # Keep the original [1:] slicing behavior, which drops the leading
        # entry (presumably the 'stage' axis -- confirm against the callers).
        return P(*processed_pps[1:])

    return jax.tree_util.tree_map_with_path(
        _process_pps,
        physical_partition_spec,
    )
6173
6274
def remove_gathered_axes_from_physical_partition_spec(pps, is_moe_block_0):
    """Remove 'fsdp', 'fsdp_transpose', and conditionally 'expert' from a spec.

    Args:
      pps: A physical ``PartitionSpec``. Any other value is returned unchanged.
      is_moe_block_0: When true, the 'expert' axis is kept; otherwise it is
        removed together with 'fsdp' and 'fsdp_transpose'.

    Returns:
      A new ``PartitionSpec`` with the same number of entries as ``pps``. A
      string entry naming a removed axis becomes ``None``. A list/tuple entry
      is filtered and returned as a tuple, or as ``None`` when nothing is left.

    Raises:
      ValueError: If an entry is not ``None``, a string, a list or a tuple.
    """
    if not isinstance(pps, P):
        return pps

    # 'fsdp' and 'fsdp_transpose' are always removed; 'expert' is removed
    # everywhere except inside MoeBlock_0.
    axes_to_remove = ("fsdp", "fsdp_transpose")
    if not is_moe_block_0:
        axes_to_remove += ("expert",)

    new_spec = []
    for axis in pps:
        if axis is None:
            new_spec.append(None)
        elif isinstance(axis, str):
            # A removed axis is replaced by None to signify replication.
            new_spec.append(None if axis in axes_to_remove else axis)
        elif isinstance(axis, (list, tuple)):
            kept = tuple(a for a in axis if a not in axes_to_remove)
            # If everything was filtered out, emit None rather than () so a
            # fully removed entry looks the same as in the string case above.
            new_spec.append(kept or None)
        else:
            raise ValueError(f"Unsupported_axis_type: {type(axis)}")

    return P(*new_spec)
110+
111+
63112def get_logical_spec_repeats_removed (full_logical ):
64113 """Removes 'circular_repeats' from logical partition spec."""
65114 if full_logical is None :
0 commit comments