Commit af82e24

SurbhiJainUSC authored and A9isha committed
Update import paths for maxtext_utils and pyconfig in tests
PiperOrigin-RevId: 877446791
1 parent b5f41ec commit af82e24

11 files changed

Lines changed: 122 additions & 273 deletions


.github/workflows/run_jupyter_notebooks.yml

Lines changed: 6 additions & 21 deletions
@@ -57,33 +57,18 @@ jobs:
       - name: Install MaxText and Dependencies
         shell: bash
         run: |
+          # 1. Create virtual environment
           python3 -m uv venv --seed
           source .venv/bin/activate
-
-          # Install MaxText package
           maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null)
-          uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest
-          uv pip install -r src/install_maxtext_extra_deps/extra_deps_from_github.txt
 
-          # Install dependencies for running notebooks
-          uv pip install papermill ipykernel ipywidgets
+          # 2. Install MaxText package and all the post training dependencies
+          uv pip install ${maxtext_wheel}[tpu-post-train] --resolution=lowest
+          #TODO: @mazumdera: replace this with the following after release
+          # uv pip install maxtext[tpu-post-train] --resolution=lowest
+          install_maxtext_tpu_post_train_extra_deps
           .venv/bin/python3 -m ipykernel install --user --name maxtext_venv
 
-          # Install Tunix for post-training notebooks
-          git clone https://github.com/google/tunix
-          uv pip install ./tunix
-
-          # Install vllm for post-training notebooks
-          git clone https://github.com/vllm-project/vllm.git
-          VLLM_TARGET_DEVICE="tpu" uv pip install ./vllm
-
-          # Install tpu-inference for post-training notebooks
-          git clone https://github.com/vllm-project/tpu-inference.git
-          uv pip install ./tpu-inference
-
-          uv pip install --no-deps qwix==0.1.4
-          uv pip install --no-deps protobuf==5.29.5
-          uv pip install math-verify==0.9.0
           python3 -m pip freeze
       - name: Run Post-Training Notebooks
         shell: bash
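The net effect of this workflow change is that the per-repository clones of Tunix, vLLM, and tpu-inference (plus the pinned qwix, protobuf, and math-verify installs) are replaced by the single `tpu-post-train` extra and the `install_maxtext_tpu_post_train_extra_deps` helper. A minimal sketch of the same flow outside CI, assuming a locally built `maxtext` wheel sits in the current directory, looks like this:

```bash
# Create and activate the virtual environment
python3 -m uv venv --seed
source .venv/bin/activate

# Install the local wheel with the consolidated post-training extra,
# then pull the GitHub-only extras via the packaged helper command
maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null)
uv pip install "${maxtext_wheel}[tpu-post-train]" --resolution=lowest
install_maxtext_tpu_post_train_extra_deps

# Register the environment as a Jupyter kernel for the notebook runs
.venv/bin/python3 -m ipykernel install --user --name maxtext_venv
```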

dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 38 additions & 10 deletions
@@ -11,6 +11,7 @@ antlr4-python3-runtime>=4.9.3
 anyio>=4.11.0
 aqtp>=0.9.0
 array_record>=0.8.3
+asttokens>=3.0.1
 astor>=0.8.1
 astroid>=4.0.2
 astunparse>=1.6.3
@@ -24,7 +25,7 @@ botocore>=1.42.56
 build>=1.3.0
 cachetools>=6.2.2
 cbor2>=5.8.0
-certifi>=2025.11.12
+certifi>=2026.2.25
 cffi>=2.0.0
 cfgv>=3.5.0
 charset-normalizer>=3.4.4
@@ -38,6 +39,7 @@ clu>=0.0.12
 cmake>=4.2.1
 colorama>=0.4.6
 colorful>=0.5.8
+comm>=0.2.3
 compressed-tensors>=0.13.0
 contourpy>=1.3.3
 coverage>=7.12.0
@@ -46,6 +48,7 @@ cycler>=0.12.1
 dacite>=1.9.2
 dataclasses-json>=0.6.7
 datasets>=4.6.0
+debugpy>=1.8.20
 decorator>=5.2.1
 depyf>=0.20.0
 dill>=0.4.0
@@ -60,13 +63,16 @@ editdistance>=0.8.1
 einops>=0.8.1
 einshape>=1.0
 email-validator>=2.3.0
+entrypoints>=0.4
 etils>=1.13.0
 evaluate>=0.4.6
 execnet>=2.1.2
+executing>=2.2.1
 fastapi>=0.122.0
 fastapi-cli>=0.0.24
 fastapi-cloud-cli>=0.13.0
 fastar>=0.8.0
+fastjsonschema>=2.21.2
 filelock>=3.20.0
 flatbuffers>=25.9.23
 flax>=0.12.4
@@ -126,13 +132,18 @@ importlib_metadata>=8.7.0
 importlib_resources>=6.5.2
 iniconfig>=2.3.0
 interegular>=0.3.3
+ipykernel>=7.2.0
+ipython>=9.10.0
+ipython_pygments_lexers>=1.1.1
+ipywidgets>=8.1.8
 isort>=7.0.0
 jaraco.classes>=3.4.0
 jaraco.context>=6.1.0
 jaraco.functools>=4.3.0
 jax>=0.8.3
 jaxlib>=0.8.3
 jaxtyping>=0.3.3
+jedi>=0.19.2
 jeepney>=0.9.0
 Jinja2>=3.1.6
 jiter>=0.13.0
@@ -141,6 +152,9 @@ joblib>=1.5.2
 jsonlines>=4.0.0
 jsonschema>=4.26.0
 jsonschema-specifications>=2025.9.1
+jupyter_client>=8.8.0
+jupyter_core>=5.9.1
+jupyterlab_widgets>=3.0.16
 kagglehub>=0.3.13
 keras>=3.12.0
 keyring>=25.7.0
@@ -162,6 +176,7 @@ MarkupSafe>=3.0.3
 marshmallow>=3.26.2
 math-verify>=0.9.0
 matplotlib>=3.10.7
+matplotlib-inline>=0.2.1
 mccabe>=0.7.0
 mcp>=1.26.0
 mdurl>=0.1.2
@@ -178,6 +193,8 @@ multidict>=6.7.0
 multiprocess>=0.70.18
 mypy_extensions>=1.1.0
 namex>=0.1.0
+nbclient>=0.10.4
+nbformat>=5.10.4
 nest-asyncio>=1.6.0
 networkx>=3.6
 ninja>=1.13.0
@@ -186,7 +203,7 @@ nltk>=3.9.2
 nodeenv>=1.9.1
 numba>=0.62.1
 numpy>=2.2.6
-numpy-typing-compat>=20250818.2.0
+numpy-typing-compat>=20250818.2.2
 nvidia-cublas-cu12>=12.8.4.1
 nvidia-cuda-cupti-cu12>=12.8.90
 nvidia-cuda-nvrtc-cu12>=12.8.93
@@ -218,7 +235,7 @@ opentelemetry-exporter-prometheus>=0.60b1
 opentelemetry-proto>=1.39.1
 opentelemetry-sdk>=1.39.1
 opentelemetry-semantic-conventions>=0.60b1
-opentelemetry-semantic-conventions-ai>=0.4.14
+opentelemetry-semantic-conventions-ai>=0.4.15
 opt_einsum>=3.4.0
 optax>=0.2.6
 optree>=0.18.0
@@ -228,23 +245,29 @@ orbax-export>=0.0.8
 outlines_core>=0.2.11
 packaging>=26.0
 pandas>=2.3.3
+papermill>=2.7.0
 parameterized>=0.9.0
+parso>=0.8.6
 partial-json-parser>=0.2.1.1.post7
 pathspec>=0.12.1
 pathwaysutils>=0.1.4
 perfetto>=0.16.0
+pexpect>=4.9.0
 pillow>=12.0.0
-platformdirs>=4.5.0
+platformdirs>=4.9.2
 pluggy>=1.6.0
 portpicker>=1.6.0
 pre_commit>=4.5.0
 prometheus-fastapi-instrumentator>=7.1.0
 prometheus_client>=0.23.1
 promise>=2.3
+prompt_toolkit>=3.0.52
 propcache>=0.4.1
 proto-plus>=1.26.1
 protobuf>=5.29.6
-psutil>=7.1.3
+psutil>=7.2.2
+ptyprocess>=0.7.0
+pure_eval>=0.2.3
 py-cpuinfo>=9.0.0
 py-spy>=0.4.1
 pyarrow>=22.0.0
@@ -316,18 +339,19 @@ sniffio>=1.3.1
 sortedcontainers>=2.4.0
 sse-starlette>=3.2.0
 starlette>=0.50.0
+stack-data>=0.6.3
 supervisor>=4.3.0
 sympy>=1.14.0
 tabulate>=0.9.0
-tenacity>=9.1.2
-tensorboard>=2.19.0
+tenacity>=9.1.4
+tensorboard>=2.20.0
 tensorboard-data-server>=0.7.2
 tensorboard-plugin-profile>=2.13.0
 tensorboardX>=2.6.4
-tensorflow>=2.19.1
+tensorflow>=2.20.0
 tensorflow-datasets>=4.9.9
 tensorflow-metadata>=1.17.2
-tensorflow-text>=2.19.0
+tensorflow-text>=2.20.0
 tensorstore>=0.1.79
 termcolor>=3.2.0
 tiktoken>=0.12.0
@@ -339,8 +363,10 @@ toolz>=1.1.0
 torch>=2.9.0
 torchax>=0.0.11
 torchvision>=0.24.0
+tornado>=6.5.4
 tpu-info>=0.7.1
-tqdm>=4.67.1
+tqdm>=4.67.3
+traitlets>=5.14.3
 transformers>=4.57.1
 treescope>=0.1.10
 triton>=3.5.0
@@ -358,9 +384,11 @@ uvloop>=0.22.1
 virtualenv>=20.35.4
 wadler_lindig>=0.1.7
 watchfiles>=1.1.1
+wcwidth>=0.6.0
 websockets>=15.0.1
 Werkzeug>=3.1.3
 wheel>=0.46.3
+widgetsnbextension>=4.0.15
 wrapt>=2.0.1
 xgrammar>=0.1.29
 xprof>=2.21.1
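The added pins (asttokens, ipykernel, ipython, ipywidgets, jupyter_client, jupyter_core, nbclient, nbformat, papermill, traitlets, tornado, and related packages) fold the notebook tooling that the workflow previously installed ad hoc into the generated lower-bound requirements, alongside routine version bumps. As a rough sketch of how such a generated file could be consumed directly with uv (this usage is an assumption, not something documented in this commit):

```bash
# Sketch: install the generated lower-bound pins with uv, resolving to the
# lowest versions that satisfy them (mirrors the --resolution=lowest used above)
uv pip install -r dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt --resolution=lowest
```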

docs/tutorials/posttraining/rl.md

Lines changed: 19 additions & 47 deletions
@@ -44,73 +44,45 @@ Let's get started!
 
 ## Create virtual environment and Install MaxText dependencies
 
-If you have already completed the
-[MaxText installation](../../install_maxtext.md), you can skip to the next
-section for post-training dependencies installations. Otherwise, please install
-`MaxText` using the following commands before proceeding.
-
 ```bash
-# 1. Clone the repository
-git clone https://github.com/AI-Hypercomputer/maxtext.git
-cd maxtext
-
-# 2. Create virtual environment
+# Create a virtual environment
 export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
 pip install uv
 uv venv --python 3.12 --seed $VENV_NAME
 source $VENV_NAME/bin/activate
-
-# 3. Install dependencies in editable mode
-uv pip install -e .[tpu] --resolution=lowest
-install_maxtext_github_deps
 ```
 
-## Install Post-Training dependencies
+### Option 1: From PyPI releases (Recommended)
 
-### Option 1: From PyPI releases
+Run the following commands to get all the necessary installations.
 
-> **Caution:** RL in MaxText is currently broken with PyPI releases of
-> post-training dependencies. We are working on fixing this and recommend
-> following [Option 2: From Github](#option-2-from-github) in the meantime.
+```bash
+uv pip install maxtext[tpu-post-train] --resolution=lowest
+install_maxtext_tpu_post_train_extra_deps
+```
 
-Next, run the following bash script to get all the necessary installations
-inside the virtual environment (for e.g., `maxtext_venv`). This will take few
-minutes. Follow along the installation logs and look out for any issues!
+It installs MaxText and then for post-training, it installs primarily the following:
 
-```
-bash tools/setup/setup_post_training_requirements.sh
-```
+a. [Tunix](https://github.com/google/tunix) as the LLM Post-Training Library, and
 
-Primarily, it installs `Tunix`, and `vllm-tpu` which is
+b. `vllm-tpu` which is
 [vllm](https://github.com/vllm-project/vllm) and
 [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby
 providing TPU inference for vLLM, with unified JAX and PyTorch support.
 
 ### Option 2: From Github
 
-You can also locally git clone [tunix](https://github.com/google/tunix) and
-install using the instructions
-[here](https://github.com/google/tunix?tab=readme-ov-file#installation).
-Similarly install [vllm](https://github.com/vllm-project/vllm) and
-[tpu-inference](https://github.com/vllm-project/tpu-inference) from source
-following the instructions
-[here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source).
-To get a set of compatible commit IDs for `maxtext`, `tunix`, `tpu-inference`,
-and `vllm`, follow these steps:
-
-1. Navigate to the
-   [MaxText Package Tests](https://github.com/AI-Hypercomputer/maxtext/actions/workflows/build_and_test_maxtext.yml?query=event%3Aschedule)
-   GitHub Actions workflow.
-
-2. Select the latest successful run.
+For using a version newer than the latest PyPI release, you could also install the latest vetted versions of the dependencies from MaxText in the following way:
 
-3. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job.
-
-4. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`,
-   `tpu-inference`, and `vllm` that were used in that successful run are listed
-   in the logs of this step.
+```bash
+# 1. Clone the repository
+git clone https://github.com/AI-Hypercomputer/maxtext.git
+cd maxtext
 
-5. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: `git checkout <commit_id>`.
+# 2. Install dependencies in editable mode
+uv pip install -e .[tpu-post-train] --resolution=lowest
+install_maxtext_tpu_post_train_extra_deps
+```
 
 ## Setup environment variables
 
docs/tutorials/posttraining/sft.md

Lines changed: 9 additions & 10 deletions
@@ -24,22 +24,21 @@ We use [Tunix](https://github.com/google/tunix), a JAX-based library designed fo
 
 In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started!
 
-## Install dependencies
+## Install MaxText and Post-Training dependencies
 
-```sh
-# 1. Clone the repository
-git clone https://github.com/AI-Hypercomputer/maxtext.git
-cd maxtext
-
-# 2. Create virtual environment
+```bash
+# Create a virtual environment
 export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
 pip install uv
 uv venv --python 3.12 --seed $VENV_NAME
 source $VENV_NAME/bin/activate
+```
+
+Run the following commands to get all the necessary installations.
 
-# 3. Install dependencies in editable mode
-uv pip install -e .[tpu] --resolution=lowest
-bash tools/setup/setup_post_training_requirements.sh
+```bash
+uv pip install maxtext[tpu-post-train] --resolution=lowest
+install_maxtext_tpu_post_train_extra_deps
 ```
 
 ## Setup environment variables

src/install_maxtext_extra_deps/install_post_train_extra_deps.py

Lines changed: 7 additions & 7 deletions
@@ -35,7 +35,7 @@ def main():
   """
   script_dir = Path(__file__).resolve().parent
 
-  os.environ['VLLM_TARGET_DEVICE'] = 'tpu'
+  os.environ["VLLM_TARGET_DEVICE"] = "tpu"
 
   # Adjust this path if your extra_post_train_deps_from_github.txt is in a different location,
   # e.g., script_dir / "data" / "extra_post_train_deps_from_github.txt"
@@ -67,12 +67,12 @@ def main():
 
   local_vllm_install_command = [
       sys.executable,  # Use the current Python executable's pip to ensure the correct environment
-        "-m",
-        "uv",
-        "pip",
-        "install",
-        "src/maxtext/integration/vllm",
-        "--no-deps",
+      "-m",
+      "uv",
+      "pip",
+      "install",
+      "src/maxtext/integration/vllm",  # MaxText on vllm installations
+      "--no-deps",
   ]
 
   print(f"Installing extra dependencies from '{extra_deps_file}' using uv...")
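For reference, the argument list assembled above, once prefixed with the current Python interpreter, corresponds to the following shell invocation. This is shown only as an illustrative expansion run from the repository root, not as code taken from the script itself:

```bash
# Illustrative expansion of local_vllm_install_command
python -m uv pip install src/maxtext/integration/vllm --no-deps
```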
