Skip to content

Commit ec8aa5f

Browse files
committed
initial commit for seed-env integration
1 parent 6e3b58b commit ec8aa5f

7 files changed

Lines changed: 414 additions & 19 deletions

File tree

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
--extra-index-url https://download.pytorch.org/whl/cpu
2+
absl-py
3+
aqtp
4+
datasets
5+
einops
6+
flax
7+
ftfy
8+
google-cloud-storage
9+
grain
10+
hf_transfer
11+
huggingface_hub
12+
imageio-ffmpeg
13+
imageio
14+
jax
15+
jaxlib
16+
Jinja2
17+
opencv-python-headless
18+
optax
19+
orbax-checkpoint
20+
parameterized
21+
Pillow
22+
pyink
23+
pylint
24+
pytest
25+
ruff
26+
scikit-image
27+
sentencepiece
28+
tensorboard-plugin-profile
29+
tensorboard
30+
tensorboardx
31+
tensorflow-datasets
32+
tensorflow
33+
tokamax
34+
tokenizers
35+
transformers
36+
37+
# pinning torch and torchvision to specific versions to avoid
38+
# installing GPU versions from PyPI when running seed-env
39+
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
40+
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
41+
qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip
42+
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Generated by seed-env. Do not edit manually.
2+
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.
3+
4+
absl-py>=2.3.1
5+
aiofiles>=25.1.0
6+
aiohappyeyeballs>=2.6.1
7+
aiohttp>=3.13.3
8+
aiosignal>=1.4.0
9+
aqtp>=0.9.0
10+
array-record>=0.8.3 ; sys_platform != 'win32'
11+
astroid>=4.0.4
12+
astunparse>=1.6.3
13+
attrs>=25.4.0
14+
auditwheel>=6.6.0
15+
black>=25.12.0
16+
build>=1.4.0
17+
certifi>=2026.1.4
18+
cffi>=2.0.0 ; platform_python_implementation != 'PyPy'
19+
charset-normalizer>=3.4.4
20+
cheroot>=11.1.2
21+
chex>=0.1.91
22+
click>=8.3.1
23+
cloudpickle>=3.1.2
24+
colorama>=0.4.6
25+
contourpy>=1.3.3
26+
cryptography>=46.0.5
27+
cycler>=0.12.1
28+
dataclasses-json>=0.6.7
29+
datasets>=2.14.4
30+
decorator>=5.2.1
31+
dill>=0.3.7
32+
dm-tree>=0.1.9
33+
docstring-parser>=0.17.0
34+
einops>=0.8.2
35+
etils>=1.13.0
36+
execnet>=2.1.2
37+
filelock>=3.20.3
38+
flatbuffers>=25.12.19
39+
flax>=0.12.4
40+
fonttools>=4.61.1
41+
frozenlist>=1.8.0
42+
fsspec>=2026.1.0
43+
ftfy>=6.3.1
44+
gast>=0.7.0
45+
gcsfs>=2026.1.0
46+
google-api-core>=2.29.0
47+
google-auth-oauthlib>=1.2.4
48+
google-auth>=2.48.0
49+
google-cloud-core>=2.5.0
50+
google-cloud-storage-control>=1.10.0
51+
google-cloud-storage>=3.9.0
52+
google-crc32c>=1.8.0
53+
google-pasta>=0.2.0
54+
google-resumable-media>=2.8.0
55+
googleapis-common-protos>=1.72.0
56+
grain>=0.2.15
57+
grpc-google-iam-v1>=0.14.3
58+
grpcio-status>=1.76.0
59+
grpcio>=1.76.0
60+
gviz-api>=1.10.0
61+
h5py>=3.15.1
62+
hf-transfer>=0.1.9
63+
hf-xet>=1.2.1 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
64+
huggingface-hub>=0.36.2
65+
humanize>=4.15.0
66+
hypothesis>=6.142.1
67+
idna>=3.11
68+
imageio-ffmpeg>=0.6.0
69+
imageio>=2.37.2
70+
immutabledict>=4.3.0
71+
importlib-resources>=6.5.2
72+
iniconfig>=2.3.0
73+
isort>=7.0.0
74+
jaraco-functools>=4.4.0
75+
jax>=0.9.0
76+
jaxlib>=0.9.0
77+
jaxtyping>=0.3.7
78+
jinja2>=3.1.6
79+
keras>=3.13.1
80+
kiwisolver>=1.4.9
81+
lazy-loader>=0.4
82+
libclang>=18.1.1
83+
libtpu>=0.0.34 ; platform_machine == 'x86_64' and sys_platform == 'linux'
84+
markdown-it-py>=4.0.0
85+
markdown>=3.10.1
86+
markupsafe>=3.0.3
87+
marshmallow>=3.26.2
88+
matplotlib>=3.10.8
89+
mccabe>=0.7.0
90+
mdurl>=0.1.2
91+
ml-dtypes>=0.5.4
92+
more-itertools>=10.8.0
93+
mpmath>=1.3.0
94+
msgpack>=1.1.2
95+
multidict>=6.7.1
96+
multiprocess>=0.70.15
97+
mypy-extensions>=1.1.0
98+
namex>=0.1.0
99+
nest-asyncio>=1.6.0
100+
networkx>=3.6.1
101+
numpy-typing-compat>=20251206.2.0
102+
numpy>=2.0.2
103+
nvidia-cuda-cccl>=13.1.115
104+
oauthlib>=3.3.1
105+
opencv-python-headless>=4.13.0.92
106+
opt-einsum>=3.4.0
107+
optax>=0.2.6
108+
optree>=0.18.0
109+
optype>=0.15.0
110+
orbax-checkpoint>=0.11.32
111+
orbax-export>=0.0.8
112+
packaging>=26.0
113+
pandas>=3.0.0
114+
parameterized>=0.9.0
115+
pathspec>=1.0.4
116+
pillow>=12.1.0
117+
platformdirs>=4.7.1
118+
pluggy>=1.6.0
119+
portpicker>=1.6.0
120+
promise>=2.3
121+
propcache>=0.4.1
122+
proto-plus>=1.27.1
123+
protobuf>=6.33.5
124+
psutil>=7.2.1
125+
pyarrow>=23.0.0
126+
pyasn1-modules>=0.4.2
127+
pyasn1>=0.6.2
128+
pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
129+
pyelftools>=0.32
130+
pygments>=2.19.2
131+
pyink>=25.12.0
132+
pylint>=4.0.4
133+
pyparsing>=3.3.2
134+
pyproject-hooks>=1.2.0
135+
pytest-xdist>=3.8.0
136+
pytest>=8.4.2
137+
python-dateutil>=2.9.0.post0
138+
pytokens>=0.4.1
139+
pyyaml>=6.0.3
140+
qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip
141+
regex>=2026.1.15
142+
requests-oauthlib>=2.0.0
143+
requests>=2.32.5
144+
rich>=14.2.0
145+
rsa>=4.9.1
146+
ruff>=0.15.1
147+
safetensors>=0.7.0
148+
scikit-image>=0.26.0
149+
scipy-stubs>=1.17.0.1
150+
scipy>=1.17.0
151+
sentencepiece>=0.2.1
152+
setuptools>=80.10.1
153+
simple-parsing>=0.1.8
154+
simplejson>=3.20.2
155+
six>=1.17.0
156+
sortedcontainers>=2.4.0
157+
sympy>=1.14.0
158+
tensorboard-data-server>=0.7.2
159+
tensorboard-plugin-profile>=2.21.6
160+
tensorboard>=2.20.0
161+
tensorboardx>=2.6.4
162+
tensorflow-datasets>=4.9.9
163+
tensorflow-metadata>=1.17.3
164+
tensorflow>=2.20.0
165+
tensorstore>=0.1.80
166+
termcolor>=3.3.0
167+
tifffile>=2026.1.28
168+
tokamax>=0.1.0
169+
tokenizers>=0.22.2
170+
toml>=0.10.2
171+
tomlkit>=0.14.0
172+
toolz>=1.1.0
173+
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
174+
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
175+
tqdm>=4.67.3
176+
transformers>=4.57.6
177+
treescope>=0.1.10
178+
typing-extensions>=4.15.0
179+
typing-inspect>=0.9.0
180+
tzdata>=2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
181+
urllib3>=2.6.3
182+
wadler-lindig>=0.1.7
183+
wcwidth>=0.6.0
184+
werkzeug>=3.1.5
185+
wheel>=0.46.2
186+
wrapt>=2.1.1
187+
xprof>=2.21.6
188+
xxhash>=3.6.0
189+
yarl>=1.22.0
190+
zipp>=3.23.0
191+
zstandard>=0.25.0

pyproject.toml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,55 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
[build-system]
16+
requires = ["hatchling", "hatch-requirements-txt"]
17+
build-backend = "hatchling.build"
18+
19+
[tool.hatch.version]
20+
path = "src/maxdiffusion/__init__.py"
21+
22+
[project]
23+
name = "maxdiffusion"
24+
dynamic = ["version", "optional-dependencies"]
25+
requires-python = ">=3.12"
26+
readme = "README.md"
27+
license = "Apache-2.0"
28+
classifiers = [
29+
"Programming Language :: Python",
30+
]
31+
dependencies = []
32+
33+
[tool.hatch.metadata.hooks.requirements_txt.optional-dependencies]
34+
tpu = ["dependencies/requirements/generated_requirements/tpu-requirements.txt"]
35+
cuda12 = ["dependencies/requirements/generated_requirements/cuda12-requirements.txt"]
36+
37+
[project.urls]
38+
Repository = "https://github.com/AI-Hypercomputer/maxdiffusion.git"
39+
"Bug Tracker" = "https://github.com/AI-Hypercomputer/maxdiffusion/issues"
40+
41+
[tool.hatch.metadata]
42+
allow-direct-references = true
43+
44+
[tool.hatch.build.targets.wheel]
45+
packages = ["src/maxdiffusion", "src/install_maxdiffusion_extra_deps"]
46+
47+
[tool.hatch.build.targets.wheel.hooks.custom]
48+
path = "build_hooks.py"
49+
50+
[project.scripts]
51+
install_maxdiffusion_github_deps = "install_maxdiffusion_extra_deps.install_github_deps:main"
52+
153
[tool.ruff]
254
# Never enforce `E501` (line length violations).
355
ignore = ["C901", "E501", "E741", "F402", "F823", "E402", "I001"]

setup.sh

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th
3535
if [[ $REPLY =~ ^[Yy]$ ]]; then
3636
# Check if uv is installed first; if not, install uv
3737
if ! command -v uv &> /dev/null; then
38-
echo -e "\n'uv' command not found. Installing it now via the official installer..."
39-
curl -LsSf https://astral.sh/uv/install.sh | sh
38+
# echo -e "\n'uv' command not found. Installing it now via the official installer..."
39+
# curl -LsSf https://astral.sh/uv/install.sh | sh
4040

41-
echo -e "\n\e[33m'uv' has been installed.\e[0m"
42-
echo "The installer likely printed instructions to update your shell's PATH."
43-
echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script."
44-
exit 1
41+
# echo -e "\n\e[33m'uv' has been installed.\e[0m"
42+
# echo "The installer likely printed instructions to update your shell's PATH."
43+
# echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script."
44+
# exit 1
45+
pip install uv
4546
fi
4647
maxdiffusion_dir=$(pwd)
4748
cd
@@ -53,7 +54,7 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th
5354
echo "No name provided. Using default name: '$venv_name'"
5455
fi
5556
echo "Creating virtual environment '$venv_name' with Python 3.12..."
56-
uv venv --python 3.12 "$venv_name" --seed
57+
python3 -m uv venv --python 3.12 "$venv_name" --seed
5758
printf '%s\n' "$(realpath -- "$venv_name")" >> /tmp/venv_created
5859
echo -e "\n\e[32mVirtual environment '$venv_name' created successfully!\e[0m"
5960
echo "To activate it, run the following command:"
@@ -81,6 +82,8 @@ apt update -y && apt -y install gcsfuse
8182
rm -rf /var/lib/apt/lists/*
8283
EOF
8384

85+
python3 -m pip install -U setuptools wheel uv
86+
8487
# Set environment variables from command line arguments
8588
for ARGUMENT in "$@"; do
8689
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
@@ -104,7 +107,7 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
104107
fi
105108

106109
# Install dependencies from requirements.txt first
107-
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
110+
python3 -m uv pip install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
108111

109112
# Install JAX and JAXlib based on the specified mode
110113
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
@@ -113,23 +116,23 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
113116
echo "Installing stable jax, jaxlib for tpu"
114117
if [[ -n "$JAX_VERSION" ]]; then
115118
echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}"
116-
pip3 install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
119+
python3 -m uv pip install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
117120
else
118121
echo "Installing stable jax, jaxlib, libtpu
119122
for tpu"
120-
pip3 install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
123+
python3 -m uv pip install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
121124
fi
122125
elif [[ $DEVICE == "gpu" ]]; then
123126
echo "Installing stable jax, jaxlib for NVIDIA gpu"
124127
if [[ -n "$JAX_VERSION" ]]; then
125128
echo "Installing stable jax, jaxlib ${JAX_VERSION}"
126-
pip3 install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
129+
python3 -m uv pip install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
127130
else
128131
echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
129-
pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
132+
python3 -m uv pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
130133
fi
131134
export NVTE_FRAMEWORK=jax
132-
pip3 install transformer_engine[jax]==2.1.0
135+
python3 -m uv pip install transformer_engine[jax]==2.1.0
133136
fi
134137

135138
elif [[ $MODE == "nightly" ]]; then
@@ -140,22 +143,22 @@ elif [[ $MODE == "nightly" ]]; then
140143
pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
141144
# Install Transformer Engine
142145
export NVTE_FRAMEWORK=jax
143-
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
146+
python3 -m uv pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
144147
elif [[ $DEVICE == "tpu" ]]; then
145148
echo "Installing jax-nightly,jaxlib-nightly"
146149
# Install jax-nightly
147-
pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
150+
python3 -m uv pip install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
148151
# Install jaxlib-nightly
149-
pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
152+
python3 -m uv pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
150153
# Install libtpu-nightly
151-
pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
154+
python3 -m uv pip install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
152155
fi
153156
echo "Installing nightly tensorboard plugin profile"
154-
pip3 install tbp-nightly --upgrade
157+
python3 -m uv pip install tbp-nightly --upgrade
155158
else
156159
echo -e "\n\nError: You can only set MODE to [stable,nightly].\n\n"
157160
exit 1
158161
fi
159162

160163
# Install maxdiffusion
161-
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
164+
python3 -m uv pip install -U . || echo "Failed to install maxdiffusion" >&2

0 commit comments

Comments
 (0)