
Commit f27ac67

Merge pull request #2926 from AI-Hypercomputer:chengnuojin-debug-sharding
PiperOrigin-RevId: 855366909
2 parents 6801c95 + 8cca866 commit f27ac67

16 files changed

Lines changed: 65 additions & 33 deletions

src/MaxText/data_loader.py

Lines changed: 3 additions & 4 deletions
@@ -20,7 +20,7 @@
 from jax.experimental import checkify

 from MaxText import exceptions
-from MaxText.sharding import get_input_data_sharding, maybe_shard_with_name
+from MaxText.sharding import get_input_data_sharding
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -70,10 +70,9 @@ def load_next_batch_pre_sharding(self):

   def load_next_batch(self, *args, **kwargs):
     """Loads the next batch with sharding hint"""
-    return maybe_shard_with_name(
+    return jax.device_put(
         self.load_next_batch_pre_sharding(),
         self.input_data_shardings,
-        self.config.shard_mode,
     )

   def check_example_batch(self):
@@ -154,7 +153,7 @@ def _slice(data):
     self.buffer_start = slice_end
     output = jax.tree.map(_slice, self.batch_buffer)
     self.rampup_active = rampup_manager.update()
-    return maybe_shard_with_name(output, self.input_data_shardings, self.config.shard_mode)
+    return jax.device_put(output, self.input_data_shardings)


 def create_dataloader(config, mesh, data_iterator, goodput_recorder, rampup_manager):
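
The data_loader.py hunks above swap MaxText's maybe_shard_with_name wrapper for plain jax.device_put, which accepts a pytree of shardings that mirrors the batch pytree. A minimal, standalone sketch of that JAX pattern (the mesh shape, batch contents, and axis names below are invented for illustration; real MaxText shardings come from get_input_data_sharding):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy mesh and host-side batch.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
batch = {
    "inputs": np.zeros((8, 128), dtype=np.int32),
    "targets": np.zeros((8, 128), dtype=np.int32),
}
# One NamedSharding per leaf: shard the batch dimension over the "data" axis.
shardings = jax.tree.map(lambda _: NamedSharding(mesh, P("data", None)), batch)
sharded_batch = jax.device_put(batch, shardings)  # transfer + lay out in one call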

src/MaxText/gradient_accumulation.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def gradient_accumulation_loss_and_grad(

   def _maybe_shard_with_name(inputs, sharding_names):
     """Wrapper of maybe_shard_with_name with fixed shard_mode"""
-    return maybe_shard_with_name(inputs, sharding_names, config.shard_mode)
+    return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding)

   # For more efficient DP/ZeRO-1 + GA
   if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
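
Across these hunks the new config.debug_sharding flag is simply threaded through to the sharding helpers. As a rough idea of what such a flag can enable (a hypothetical sketch, not MaxText's maybe_shard_with_name), one can report the sharding JAX actually chooses for a constrained intermediate via jax.debug.inspect_array_sharding:

import jax

def constrain_with_debug(x, sharding, debug_sharding=False):
  # Apply the sharding constraint as usual.
  x = jax.lax.with_sharding_constraint(x, sharding)
  if debug_sharding:
    # Inside a jitted function, print the sharding the compiler settles on.
    jax.debug.inspect_array_sharding(x, callback=print)
  return x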

src/MaxText/layers/attention_op.py

Lines changed: 10 additions & 11 deletions
@@ -1198,17 +1198,10 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1):
         segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
       else:
         segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
-    elif (
-        self.config.use_jax_splash
-        and self.config.expert_shard_attention_option == EP_AS_FSDP
-    ):
+    elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
       if self.config.use_max_logit_estimate > 0:
-        sa_config = dataclasses.replace(
-            sa_config, max_logit_const=self.config.use_max_logit_estimate
-        )
-      segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((
-          Q_LENGTH_NO_EXP,
-      ))
+        sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
+      segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
     else:
       # Create multi-head mask
       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
@@ -1327,7 +1320,13 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
       if pspec is None:
         return None
       sharding = NamedSharding(self.mesh, pspec)
-      return maybe_shard_with_name(inputs, sharding, shard_mode=self.config.shard_mode)
+      return maybe_shard_with_name(
+          inputs,
+          sharding,
+          shard_mode=self.config.shard_mode,
+          debug_sharding=self.config.debug_sharding,
+          extra_stack_level=1,
+      )

     query = _maybe_shard_with_pspec(query, axis_names_q)
     key = _maybe_shard_with_pspec(key, axis_names_kv)
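
The second hunk builds a NamedSharding from a logical PartitionSpec and hands it to the sharding helper. A self-contained sketch of that core JAX pattern, with made-up mesh axes and array shapes, and plain jax.lax.with_sharding_constraint standing in for MaxText's helper:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

@jax.jit
def scale_query(q):
  # Constrain q to a sharding built from a PartitionSpec, as the hunk does.
  q = jax.lax.with_sharding_constraint(q, NamedSharding(mesh, P("data", None, "model")))
  return q * 2.0

out = scale_query(np.ones((4, 16, 8), dtype=np.float32))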

src/MaxText/layers/attentions.py

Lines changed: 1 addition & 0 deletions
@@ -525,6 +525,7 @@ def __init__(
         maybe_shard_with_logical,
         mesh=mesh,
         shard_mode=config.shard_mode,
+        debug_sharding=config.debug_sharding,
     )

   def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
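
This hunk, like the similar ones below in decoders.py, linears.py, and llama2.py, threads debug_sharding into what appears to be a partial application of maybe_shard_with_logical created once in __init__ so call sites only pass the array and its logical axes. A generic sketch of that binding pattern with functools.partial, using stand-in names and a no-op helper (the real MaxText signature may differ):

import functools

# Stand-in for MaxText's maybe_shard_with_logical.
def shard_with_logical(x, logical_axes, *, mesh=None, shard_mode=None, debug_sharding=False):
  # A real helper would apply a sharding constraint derived from logical_axes.
  return x

# Bind the per-model settings once; call sites then pass only (x, logical_axes).
maybe_shard = functools.partial(
    shard_with_logical,
    mesh="<mesh>",
    shard_mode="auto",
    debug_sharding=True,
)
y = maybe_shard([[1.0, 2.0]], ("activation_batch", "activation_embed"))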

src/MaxText/layers/decoders.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ def __call__(
         sharding.maybe_shard_with_logical,
         mesh=mesh,
         shard_mode=cfg.shard_mode,
+        debug_sharding=cfg.debug_sharding,
     )

     if self.model_mode == MODEL_MODE_PREFILL:

src/MaxText/layers/deepseek.py

Lines changed: 5 additions & 1 deletion
@@ -129,7 +129,11 @@ def mlp_op(self, x, deterministic, *args, **kwargs):

   def with_logical_constraint(self, x):
     return maybe_shard_with_logical(
-        x, logical_axes=self.logical_axis_names, mesh=self.mesh, shard_mode=self.config.shard_mode
+        x,
+        logical_axes=self.logical_axis_names,
+        mesh=self.mesh,
+        shard_mode=self.config.shard_mode,
+        debug_sharding=self.config.debug_sharding,
     )

   def dropout_op(self, x, deterministic):

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def mlp_logical_axis_names(self):
   def with_logical_constraint(self, x):
     return maybe_shard_with_logical(
         x, logical_axes=self.logical_axis_names,
-        mesh=self.mesh, shard_mode=self.config.shard_mode
+        mesh=self.mesh, shard_mode=self.config.shard_mode, debug_sharding=self.config.debug_sharding,
     )

   def pre_attention_norm_op(self, x):

src/MaxText/layers/linears.py

Lines changed: 1 addition & 0 deletions
@@ -459,6 +459,7 @@ def __init__(
         maybe_shard_with_logical,
         mesh=mesh,
         shard_mode=config.shard_mode,
+        debug_sharding=config.debug_sharding,
     )

   def get_norm_layer(self, num_features: int):

src/MaxText/layers/llama2.py

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ def __init__(
         maybe_shard_with_logical,
         mesh=self.mesh,
         shard_mode=config.shard_mode,
+        debug_sharding=config.debug_sharding,
     )

   def __call__(

src/MaxText/layers/moe.py

Lines changed: 7 additions & 1 deletion
@@ -444,7 +444,13 @@ def __init__(
     self.wo_bias = None

   def _maybe_shard_with_logical(self, inputs, logical_name):
-    return maybe_shard_with_logical(inputs, logical_name, mesh=self.mesh, shard_mode=self.config.shard_mode)
+    return maybe_shard_with_logical(
+        inputs,
+        logical_name,
+        mesh=self.mesh,
+        shard_mode=self.config.shard_mode,
+        debug_sharding=self.config.debug_sharding,
+    )

   def _logical_to_mesh_axes(self, logical_name):
     return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
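
The _logical_to_mesh_axes helper shown as context maps logical axis names to mesh axes via the configured rules. A tiny sketch of the underlying idea using Flax's nn.logical_to_mesh_axes (the rules and names below are invented, and MaxText's own logical_to_mesh_axes wrapper may behave differently):

import flax.linen as nn

rules = (("activation_batch", "data"), ("activation_embed", "model"))
pspec = nn.logical_to_mesh_axes(("activation_batch", "activation_embed"), rules)
print(pspec)  # PartitionSpec('data', 'model')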
