Skip to content

Commit eb9d12f

Browse files
Merge pull request #3454 from CIeNET-International:charlesli/input_shardings_unittest
PiperOrigin-RevId: 886845057
2 parents a56b677 + 448b70f commit eb9d12f

3 files changed

Lines changed: 33 additions & 10 deletions

File tree

src/maxtext/utils/sharding.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
_ACTIVATION_SHARDINGS_DUMP = []
3737

3838

39+
def clear_input_shardings_dump():
  """Reset the recorded input/activation sharding state.

  Empties `_LOGGED_ACTIVATION_SHARDINGS` and `_ACTIVATION_SHARDINGS_DUMP`
  in place by calling `clear()` on each, so any code holding a reference
  to either container sees it emptied as well. Returns None.
  """
  # Same order as before: the logged-shardings record first, then the dump.
  for recorded in (_LOGGED_ACTIVATION_SHARDINGS, _ACTIVATION_SHARDINGS_DUMP):
    recorded.clear()
43+
44+
3945
def get_input_data_sharding(config, mesh):
4046
"""Get the input data sharding for the model"""
4147
if config.enable_diloco:

tests/unit/sharding_compare_test.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
import jax.numpy as jnp
2222
from maxtext.configs import pyconfig
2323
from maxtext.utils import maxtext_utils
24+
from maxtext.utils.sharding import clear_input_shardings_dump
2425
# import optax
2526

2627
from maxtext.layers import quantizations
2728
from maxtext.models import models
2829
from maxtext.optimizers import optimizers
2930
from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config
30-
from tests.utils.sharding_dump import TEST_CASES, load_json, named_shardings_to_json, partition_specs_to_json
31+
from tests.utils.sharding_dump import TEST_CASES, load_json, input_sharding_to_json, named_shardings_to_json, partition_specs_to_json
3132
from tests.utils.test_helpers import get_test_config_path
3233
import pytest
3334

@@ -124,25 +125,34 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
124125
f"compile_topology={topology}",
125126
f"compile_topology_num_slices={num_slice}",
126127
f"model_name={model_name}",
128+
"log_config=false",
129+
"debug_sharding=true", # for input sharding dump
127130
]
128131

129132
root_dir = "tests/utils/sharding_info"
130133
base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")
131134

132135
named_json_path = os.path.join(base_path, "named_shardings.json")
133136
logical_json_path = os.path.join(base_path, "logical_shardings.json")
137+
input_json_path = os.path.join(base_path, "input_shardings.json")
134138

135139
if not os.path.exists(named_json_path):
136140
pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")
137141
return
138142
if not os.path.exists(logical_json_path):
139143
pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")
140144
return
145+
if not os.path.exists(input_json_path):
146+
pytest.skip(f"Missing input_shardings.json for {model_name} {topology} slice {num_slice}")
147+
return
141148

142149
config = pyconfig.initialize(params)
143150
validate_config(config)
144151

152+
clear_input_shardings_dump()
145153
topology_mesh = get_topology_mesh(config)
154+
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
155+
optimizers.get_optimizer(config, learning_rate_schedule)
146156
shaped_train_args, _, state_mesh_shardings, logical_shardings, _ = get_shaped_inputs(topology_mesh, config)
147157

148158
error_messages = []
@@ -173,6 +183,20 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
173183
compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)")
174184
error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")
175185

186+
# 3. Compare Input Shardings
187+
actual_input = input_sharding_to_json()
188+
expected_input = load_json(input_json_path)
189+
# calculate checksum
190+
actual_input_sum = compute_checksum(actual_input)
191+
expected_input_sum = compute_checksum(expected_input)
192+
193+
input_match = actual_input_sum == expected_input_sum
194+
195+
if not input_match:
196+
print(f"\n[FAIL] Input Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
197+
# compare_sharding_jsons(expected_input, "Expected (Input)", actual_input, "Actual (Input)")
198+
error_messages.append(f"Input sharding mismatch for {model_name} on {topology} slice {num_slice}")
199+
176200
assert not error_messages, "\n".join(error_messages)
177201

178202

tests/utils/sharding_dump.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -388,17 +388,10 @@ def partition_specs_to_json(logical_tree, shape_tree) -> dict[str, Any]:
388388
def input_sharding_to_json() -> dict[str, Any]:
  """Package the activation sharding dump as a JSON-style dict.

  Prints how many entries `_ACTIVATION_SHARDINGS_DUMP` currently holds.

  Returns:
    A one-key dict mapping "Activation Sharding Dump" to
    `_ACTIVATION_SHARDINGS_DUMP` itself (the same object, not a copy).
  """
  entry_count = len(_ACTIVATION_SHARDINGS_DUMP)
  print(f"Got {entry_count} Input entries.")
  return {"Activation Sharding Dump": _ACTIVATION_SHARDINGS_DUMP}
392393

393394

394-
def save_activation_shading_dict(output_path: str | Path, sharding_dict: dict) -> None:
395-
"""Save the activation sharding dict directly to a JSON file."""
396-
output_path = Path(output_path)
397-
output_path.parent.mkdir(parents=True, exist_ok=True)
398-
with open(output_path, "w", encoding="utf-8") as f:
399-
json.dump(sharding_dict, f, indent=2)
400-
401-
402395
def save_json(output_path: str | Path, sharding_dict: dict) -> None:
403396
"""Save dict to a JSON file."""
404397
output_path = Path(output_path)
@@ -408,7 +401,7 @@ def save_json(output_path: str | Path, sharding_dict: dict) -> None:
408401

409402

410403
def load_json(json_path: str | Path) -> dict:
411-
"""Loads the named_shardings.json file into a plain Python dict."""
404+
"""Loads json file into a plain Python dict."""
412405
json_path = Path(json_path)
413406
with open(json_path, "r", encoding="utf-8") as f:
414407
return json.load(f)

0 commit comments

Comments
 (0)