@@ -35,13 +35,14 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th
3535 if [[ $REPLY =~ ^[Yy]$ ]]; then
3636 # Check if uv is installed first; if not, install uv
3737 if ! command -v uv & > /dev/null; then
38- echo -e " \n'uv' command not found. Installing it now via the official installer..."
39- curl -LsSf https://astral.sh/uv/install.sh | sh
38+ # echo -e "\n'uv' command not found. Installing it now via the official installer..."
39+ # curl -LsSf https://astral.sh/uv/install.sh | sh
4040
41- echo -e " \n\e[33m'uv' has been installed.\e[0m"
42- echo " The installer likely printed instructions to update your shell's PATH."
43- echo " Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script."
44- exit 1
41+ # echo -e "\n\e[33m'uv' has been installed.\e[0m"
42+ # echo "The installer likely printed instructions to update your shell's PATH."
43+ # echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script."
44+ # exit 1
45+ pip install uv
4546 fi
4647 maxdiffusion_dir=$( pwd)
4748 cd
@@ -53,7 +54,7 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th
5354 echo " No name provided. Using default name: '$venv_name '"
5455 fi
5556 echo " Creating virtual environment '$venv_name ' with Python 3.12..."
56- uv venv --python 3.12 " $venv_name " --seed
57+ python3 -m uv venv --python 3.12 " $venv_name " --seed
5758 printf ' %s\n' " $( realpath -- " $venv_name " ) " >> /tmp/venv_created
5859 echo -e " \n\e[32mVirtual environment '$venv_name ' created successfully!\e[0m"
5960 echo " To activate it, run the following command:"
@@ -81,6 +82,8 @@ apt update -y && apt -y install gcsfuse
8182rm -rf /var/lib/apt/lists/*
8283EOF
8384
85+ python3 -m pip install -U setuptools wheel uv
86+
8487# Set environment variables from command line arguments
8588for ARGUMENT in " $@ " ; do
8689 IFS=' =' read -r KEY VALUE <<< " $ARGUMENT"
@@ -104,7 +107,7 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
104107fi
105108
106109# Install dependencies from requirements.txt first
107- pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
110+ python3 -m uv pip install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
108111
109112# Install JAX and JAXlib based on the specified mode
110113if [[ " $MODE " == " stable" || ! -v MODE ]]; then
@@ -113,23 +116,23 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
113116 echo " Installing stable jax, jaxlib for tpu"
114117 if [[ -n " $JAX_VERSION " ]]; then
115118 echo " Installing stable jax, jaxlib, libtpu version ${JAX_VERSION} "
116- pip3 install " jax[tpu]==${JAX_VERSION} " -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
119+ python3 -m uv pip install " jax[tpu]==${JAX_VERSION} " -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
117120 else
118121 echo " Installing stable jax, jaxlib, libtpu
119122 for tpu"
120- pip3 install ' jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
123+ python3 -m uv pip install ' jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
121124 fi
122125 elif [[ $DEVICE == " gpu" ]]; then
123126 echo " Installing stable jax, jaxlib for NVIDIA gpu"
124127 if [[ -n " $JAX_VERSION " ]]; then
125128 echo " Installing stable jax, jaxlib ${JAX_VERSION} "
126- pip3 install -U " jax[cuda12]==${JAX_VERSION} " -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
129+ python3 -m uv pip install -U " jax[cuda12]==${JAX_VERSION} " -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
127130 else
128131 echo " Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
129- pip3 install " jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
132+ python3 -m uv pip install " jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
130133 fi
131134 export NVTE_FRAMEWORK=jax
132- pip3 install transformer_engine[jax]==2.1.0
135+ python3 -m uv pip install transformer_engine[jax]==2.1.0
133136 fi
134137
135138elif [[ $MODE == " nightly" ]]; then
@@ -140,22 +143,22 @@ elif [[ $MODE == "nightly" ]]; then
140143 pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
141144 # Install Transformer Engine
142145 export NVTE_FRAMEWORK=jax
143- pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
146+ python3 -m uv pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
144147 elif [[ $DEVICE == " tpu" ]]; then
145148 echo " Installing jax-nightly,jaxlib-nightly"
146149 # Install jax-nightly
147- pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
150+ python3 -m uv pip install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
148151 # Install jaxlib-nightly
149- pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
152+ python3 -m uv pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
150153 # Install libtpu-nightly
151- pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
154+ python3 -m uv pip install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
152155 fi
153156 echo " Installing nightly tensorboard plugin profile"
154- pip3 install tbp-nightly --upgrade
157+ python3 -m uv pip install tbp-nightly --upgrade
155158else
156159 echo -e " \n\nError: You can only set MODE to [stable,nightly].\n\n"
157160 exit 1
158161fi
159162
160163# Install maxdiffusion
161- pip3 install -U . || echo " Failed to install maxdiffusion" >&2
164+ python3 -m uv pip install -U . || echo " Failed to install maxdiffusion" >&2
0 commit comments