Skip to content

Commit ec4166e

Browse files
authored
Fix the Tensorflow version (#154)
* Remove the transformer engine from setup script since it's not required and it's slowing down the setup. * Revert the Transformer Engine change. Fix the tensorflow version.
1 parent 98bb4ee commit ec4166e

3 files changed

Lines changed: 12 additions & 8 deletions

File tree

docs/getting_started/first_run.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ multiple hosts.
1212
1. Clone MaxDiffusion in your TPU VM.
1313
1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running:
1414
```bash
15-
bash setup.sh MODE=stable
15+
If you are running on TPU:
16+
bash setup.sh MODE=stable DEVICE=tpu
17+
18+
If you are running on GPU:
19+
bash setup.sh MODE=stable DEVICE=gpu
1620
```
1721

1822
## Getting Starting: Multihost development

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Pillow
1919
pylint
2020
pyink
2121
pytest==8.2.2
22-
tensorflow>=2.17.0
22+
tensorflow==2.17.0
2323
tensorflow-datasets>=4.9.6
2424
ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
@@ -29,4 +29,4 @@ tokenizers==0.21.0
2929
huggingface_hub==0.24.7
3030
transformers==4.48.1
3131
einops==0.8.0
32-
sentencepiece
32+
sentencepiece

setup.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ fi
5858
# Install JAX and JAXlib based on the specified mode
5959
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
6060
# Stable mode
61-
if [[ $DEVICE == "tpu" ]]; then
61+
if [[ $DEVICE == "tpu" ]]; then
6262
echo "Installing stable jax, jaxlib for tpu"
6363
if [[ -n "$JAX_VERSION" ]]; then
6464
echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}"
@@ -90,14 +90,14 @@ elif [[ $MODE == "nightly" ]]; then
9090
# Install Transformer Engine
9191
export NVTE_FRAMEWORK=jax
9292
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
93-
elif [[ $DEVICE == "tpu" ]]; then
93+
elif [[ $DEVICE == "tpu" ]]; then
9494
echo "Installing jax-nightly,jaxlib-nightly"
9595
# Install jax-nightly
9696
pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
9797
# Install jaxlib-nightly
9898
pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
9999
# Install libtpu-nightly
100-
pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
100+
pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
101101
fi
102102
echo "Installing nightly tensorboard plugin profile"
103103
pip3 install tbp-nightly --upgrade
@@ -107,7 +107,7 @@ else
107107
fi
108108

109109
# Install dependencies from requirements.txt
110-
pip3 install -U -r requirements.txt
110+
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
111111

112112
# Install maxdiffusion
113-
pip3 install -U .
113+
pip3 install -U . || echo "Failed to install maxdiffusion" >&2

0 commit comments

Comments
 (0)