Skip to content

Commit a07af41

Browse files
committed
Add post-training dependencies to MaxText
1 parent 17d805e commit a07af41

8 files changed

Lines changed: 486 additions & 248 deletions

File tree

dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 366 additions & 245 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@ allow-direct-references = true
4141
[tool.hatch.build.targets.wheel]
4242
packages = ["src/MaxText", "src/maxtext", "src/install_maxtext_extra_deps"]
4343

44-
[tool.hatch.build.targets.wheel.hooks.custom]
45-
path = "build_hooks.py"
44+
# TODO: Add this hook back when it handles device-type parsing
45+
# [tool.hatch.build.targets.wheel.hooks.custom]
46+
# path = "build_hooks.py"
4647

4748
[project.scripts]
48-
install_maxtext_github_deps = "install_maxtext_extra_deps.install_github_deps:main"
49+
install_maxtext_tpu_github_deps = "install_maxtext_extra_deps.install_github_deps:main"
50+
install_maxtext_cuda12_github_deps = "install_maxtext_extra_deps.install_github_deps:main"
51+
install_maxtext_tpu_post_train_github_deps = "install_maxtext_extra_deps.install_post_train_github_deps:main"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
2+
google-tunix @ https://github.com/google/tunix/archive/b12123542511e45814ec3ff798ec4fb30b0b48e8.zip
3+
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
4+
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/eb2bd54c70f443775d8870caf56b78ab0b671269.zip
5+
vllm @ git+https://github.com/vllm-project/vllm@3bbb2046ff320395c80c139e55e7c1947c3fb5e1
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Installs extra dependencies from a requirements file using uv.
16+
17+
This script is designed to be run to install dependencies specified in
18+
'extra_post_train_deps_from_github.txt', which is expected to be in the same directory.
19+
It first ensures 'uv' is installed and then uses it to install the packages
20+
listed in the requirements file.
21+
"""
22+
23+
import os
24+
import subprocess
25+
import sys
26+
from pathlib import Path
27+
28+
29+
def main():
30+
"""
31+
Installs extra dependencies specified in extra_post_train_deps_from_github.txt using uv.
32+
33+
This script looks for 'extra_post_train_deps_from_github.txt' relative to its own location.
34+
It executes 'uv pip install -r <path_to_extra_deps.txt> --resolution=lowest'.
35+
"""
36+
script_dir = Path(__file__).resolve().parent
37+
38+
os.environ['VLLM_TARGET_DEVICE'] = 'tpu'
39+
40+
# Adjust this path if your extra_post_train_deps_from_github.txt is in a different location,
41+
# e.g., script_dir / "data" / "extra_post_train_deps_from_github.txt"
42+
extra_deps_file = script_dir / "extra_post_train_deps_from_github.txt"
43+
44+
if not extra_deps_file.exists():
45+
print(f"Error: '{extra_deps_file}' not found.")
46+
print("Please ensure 'extra_post_train_deps_from_github.txt' is in the correct location relative to the script.")
47+
sys.exit(1)
48+
# Check if 'uv' is available in the environment
49+
try:
50+
subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
51+
subprocess.run([sys.executable, "-m", "uv", "--version"], check=True, capture_output=True)
52+
except subprocess.CalledProcessError as e:
53+
print(f"Error checking uv version: {e}")
54+
print(f"Stderr: {e.stderr.decode()}")
55+
sys.exit(1)
56+
57+
command = [
58+
sys.executable, # Use the current Python executable's pip to ensure the correct environment
59+
"-m",
60+
"uv",
61+
"pip",
62+
"install",
63+
"-r",
64+
str(extra_deps_file),
65+
"--no-deps",
66+
]
67+
68+
print(f"Installing extra dependencies from '{extra_deps_file}' using uv...")
69+
print(f"Running command: {' '.join(command)}")
70+
71+
try:
72+
# Run the command
73+
process = subprocess.run(command, check=True, capture_output=True, text=True)
74+
print("Extra dependencies installed successfully!")
75+
print("--- Output from uv ---")
76+
print(process.stdout)
77+
if process.stderr:
78+
print("--- Errors/Warnings from uv (if any) ---")
79+
print(process.stderr)
80+
except subprocess.CalledProcessError as e:
81+
print("Failed to install extra dependencies.")
82+
print(f"Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}.")
83+
print("--- Stderr ---")
84+
print(e.stderr)
85+
print("--- Stdout ---")
86+
print(e.stdout)
87+
sys.exit(e.returncode)
88+
except (OSError, FileNotFoundError) as e:
89+
print(f"An OS-level error occurred while trying to run uv: {e}")
90+
sys.exit(1)
91+
92+
93+
if __name__ == "__main__":
94+
main()

tests/unit/distillation_checkpointing_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
"""Unit tests for Distillation Checkpointing logic."""
1616

17+
import pytest
18+
19+
pytest.importorskip("tunix")
20+
pytestmark = [pytest.mark.tpu_only]
21+
1722
import json
1823
import os
1924
import shutil

tests/unit/sft_hooks_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests for training and data loading hooks for SFT"""
1616
import pytest
1717

18+
pytest.importorskip("tunix")
1819
pytestmark = [pytest.mark.tpu_only, pytest.mark.external_training]
1920

2021
import jax

tests/unit/sharding_compare_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
111111
return has_diff
112112

113113

114+
# Requires JAX TPU support to generate the simulated TPU topology.
115+
@pytest.mark.tpu_only
114116
@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
115117
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
116118
"""
@@ -217,6 +219,8 @@ def abstract_state_and_shardings(request):
217219
class TestGetAbstractState:
218220
"""Test class for get_abstract_state function and sharding comparison."""
219221

222+
# Requires JAX TPU support to generate the simulated TPU topology.
223+
@pytest.mark.tpu_only
220224
def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pylint: disable=redefined-outer-name
221225
"""Tests that get_abstract_state returns a state with the correct abstract structure and compares sharding."""
222226

tests/unit/train_distill_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515

1616
"""Unit tests for the Distillation Trainer."""
1717

18+
import pytest
19+
20+
pytest.importorskip("tunix")
21+
pytestmark = [pytest.mark.tpu_only]
22+
1823
import shutil
1924
import tempfile
2025
import unittest

0 commit comments

Comments
 (0)