Commit f7971f2

Plumb config.max_segments_per_seq to grain dataset max_sequences_per_bin
1 parent 7991534 commit f7971f2

14 files changed

Lines changed: 35 additions & 14 deletions

File tree

src/MaxText/configs/a3/llama_3.1_405b/128vm.sh

Lines changed: 1 addition & 1 deletion
@@ -52,5 +52,5 @@ python3 -m MaxText.$EXECUTABLE ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src
 dcn_fsdp_parallelism=128 \
 ici_fsdp_parallelism=8 \
 base_output_directory=$OUTPUT_PATH \
+max_segments_per_seq=32 \
 profiler=xplane
-

src/MaxText/configs/base.yml

Lines changed: 4 additions & 3 deletions
@@ -552,9 +552,10 @@ generate_padding_batch_train: False
 generate_padding_batch_eval: False
 # Maximum number of segments that can be packed into a single sequence
 # This needs to be passed to TransformerEngine's DotProductAttention layer for packing
-# NOTE: This parameter is only relevant for GPU packed attention.
-# When using TPU, this parameter will be ignored.
-max_segments_per_seq: 32
+# This also affects packing for grain, since TransformerEngine may crash or cause
+# data corruption if there are more segments packed than specified
+# Set this to something like 32 for GPUs when using TransformerEngine
+max_segments_per_seq: -1
 # Rampup batch size, similar to Megatron-LM, see
 # https://github.com/NVIDIA/Megatron-LM/blob/2a01637aa54ccdaf7ea9afc1f1b80f58c53d7f3c/megatron/core/num_microbatches_calculator.py#L233-L237
 # The ramp-up proceeds in stages from `per_device_batch_size_start` up to
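The new base.yml comment warns that TransformerEngine may crash or corrupt data if a packed sequence contains more segments than declared. The invariant can be illustrated with a small check over per-token segment IDs; the function names and the "positive segment index, 0 for padding" layout are my assumptions about a common packing convention, not MaxText's actual code:

```python
import numpy as np


def count_segments(segment_ids: np.ndarray) -> int:
    """Count distinct non-padding segments in one packed sequence.

    Convention assumed here: segment_ids holds a positive segment index
    per token and 0 for padding (a common packing layout; the real
    MaxText field names may differ).
    """
    return int(np.unique(segment_ids[segment_ids > 0]).size)


def check_packing(segment_ids: np.ndarray, max_segments_per_seq: int) -> None:
    """Raise if more segments were packed than the configured cap allows.

    A non-positive max_segments_per_seq (the new base.yml default of -1)
    disables the check.
    """
    n = count_segments(segment_ids)
    if max_segments_per_seq > 0 and n > max_segments_per_seq:
        raise ValueError(
            f"{n} segments packed, but max_segments_per_seq={max_segments_per_seq}"
        )
```

With `max_segments_per_seq: 32`, a sequence packing three documents (e.g. segment IDs `[1, 1, 1, 2, 2, 3, 0, 0]`) passes, while a cap of 2 would reject it.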

src/MaxText/configs/gpu_smoke_test.yml

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@ per_device_batch_size: 2
 max_target_length: 1024
 dataset_type: "synthetic"
 steps: 10
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/llama2_70b.yml

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ logits_dot_in_fp32: False

 per_device_batch_size: 6
 max_target_length: 4096
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/llama2_7b.yml

Lines changed: 2 additions & 1 deletion
@@ -12,4 +12,5 @@ remat_policy: "minimal_with_context"
 use_iota_embed: True
 scan_layers: False
 dataset_type: "synthetic"
-async_checkpointing: False
+async_checkpointing: False
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/llama3.1_405b.yml

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ async_checkpointing: False
 logits_dot_in_fp32: False
 per_device_batch_size: 1.0
 max_target_length: 4096
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/llama3_70b.yml

Lines changed: 1 addition & 0 deletions
@@ -28,3 +28,4 @@ use_iota_embed: True
 dataset_type: "synthetic"
 reuse_example_batch: 1
 enable_checkpointing: False
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/llama3_8b.yml

Lines changed: 1 addition & 0 deletions
@@ -28,3 +28,4 @@ use_iota_embed: True
 dataset_type: "synthetic"
 reuse_example_batch: 1
 enable_checkpointing: False
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/mixtral_8x1b.yml

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ reuse_example_batch: 1
 enable_checkpointing: False
 megablox: False
 sparse_matmul: False
+max_segments_per_seq: 32

src/MaxText/configs/models/gpu/mixtral_8x2b.yml

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ reuse_example_batch: 1
 enable_checkpointing: False
 megablox: False
 sparse_matmul: False
+max_segments_per_seq: 32
