
Commit b6e0cdb

Merge pull request #3120 from AI-Hypercomputer:sanbao/xlml
PiperOrigin-RevId: 868894248
Parents: 50ef2df + a2fdf9e

6 files changed

Lines changed: 7 additions & 7 deletions


tests/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/
 python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32

 # Run supervised fine-tuning - megablox implementation
-python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32
+python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32

 # Run decoding - megablox implementation
 # Note decode requires the access token for huggingface tokenizer even if the model is not gated
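
Every command in these scripts resolves its config path with nested bash default expansion; the only change in this commit is the tail of that default, configs -> configs/post_train. A minimal sketch of how the nesting resolves (the CONFIG_PATH name and echo harness are illustrative, not from the scripts):

# Inner default: fall back to $PWD when MAXTEXT_REPO_ROOT is unset or empty.
# Outer default: fall back to the derived repo path when MAXTEXT_CONFIGS_DIR is unset or empty.
CONFIG_PATH="${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"/sft.yml
echo "resolved config: ${CONFIG_PATH}"
# Neither variable set           -> $PWD/src/maxtext/configs/post_train/sft.yml
# MAXTEXT_REPO_ROOT=/opt/maxtext -> /opt/maxtext/src/maxtext/configs/post_train/sft.yml
# MAXTEXT_CONFIGS_DIR=/cfg       -> /cfg/sft.yml (the post_train default is bypassed entirely)

The stray double slash in "configs}"//sft.yml above is harmless: path resolution collapses repeated separators in the middle of a path.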

tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/
 python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4

 # Run supervised fine-tuning - megablox implementation
-python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4
+python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4

 # Run decoding - megablox implementation
 # Note decode requires the access token for huggingface tokenizer even if the model is not gated

tests/end_to_end/tpu/llama3.1/8b/run_sft.sh

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ fi
 echo "Running fine-tuning on checkpoint: ${PRE_TRAINED_MODEL_CKPT_PATH}"

 # Run Supervised Fine-Tuning on MaxText checkpoint using HuggingFaceH4/ultrachat_200k dataset
-python3 -m maxtext.trainers.post_train.sft.train_sft "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/sft.yml \
+python3 -m maxtext.trainers.post_train.sft.train_sft "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"/sft.yml \
 run_name=${RUN_NAME} base_output_directory=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL} \
 model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
 hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \

tests/end_to_end/tpu/run_sft.sh

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ fi
 echo "Running fine-tuning on checkpoint: ${PRE_TRAINED_MODEL_CKPT_PATH}"

 # Run Supervised Fine-Tuning on MaxText checkpoint using HuggingFaceH4/ultrachat_200k dataset
-python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft.yml \
+python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml \
 run_name=${RUN_NAME} base_output_directory=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL} \
 model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
 hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \
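
Because an exported MAXTEXT_CONFIGS_DIR replaces the entire default, including the new post_train segment, callers who pin the config directory must now point it at the subdirectory that actually holds sft.yml. A hedged usage sketch (the /opt/maxtext checkout path is hypothetical):

# Hypothetical checkout location; adjust to your environment.
export MAXTEXT_REPO_ROOT=/opt/maxtext
# Must include post_train after this commit, or sft.yml will not be found.
export MAXTEXT_CONFIGS_DIR="${MAXTEXT_REPO_ROOT}/src/maxtext/configs/post_train"
bash tests/end_to_end/tpu/run_sft.sh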

tests/end_to_end/tpu/test_sft_trainer.sh

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ PER_DEVICE_BATCH_SIZE=1
 LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass

 # SFT with HF pipeline
-python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/sft.yml \
+python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"/sft.yml \
 run_name=${RUN_NAME}-hf base_output_directory=${BASE_OUTPUT_DIRECTORY} \
 model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
 dataset_type=hf hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \

@@ -45,7 +45,7 @@ largest_dir="${sorted_dirs[-1]}"
 FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/items

 # Decode
-python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/sft.yml \
+python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"/sft.yml \
 run_name=${RUN_NAME}-hf-decode \
 model_name=${PRE_TRAINED_MODEL} tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} tokenizer_type=huggingface \
 load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \
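
The second hunk's context shows how this test picks the checkpoint to decode from: step directories are sorted and the largest one is taken. A minimal sketch of that selection under assumed local, purely numeric directory names (the ls | sort -n enumeration is illustrative; the actual script's listing code is not shown in this diff):

# Collect step directories under ${CHECKPOINTS_PATH} and sort them numerically.
sorted_dirs=($(ls "${CHECKPOINTS_PATH}" | sort -n))
# Negative array indexing (bash 4.3+) selects the highest step number.
largest_dir="${sorted_dirs[-1]}"
FINE_TUNED_MODEL_CKPT_PATH="${CHECKPOINTS_PATH}/${largest_dir}/items"
echo "decoding from: ${FINE_TUNED_MODEL_CKPT_PATH}"

The trailing /items matches the checkpoint layout visible in the hunk's context line above.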

tests/integration/sft_trainer_correctness_test.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def get_golden_data(model_name):
 def initialize_config():
   """Initialize configurations."""
   return pyconfig.initialize(
-      [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "sft.yml")],
+      [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs/post_train", "sft.yml")],
       run_name="test-sft-trainer-correctness",
       model_name="default",
       tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer"),
