@@ -154,6 +154,11 @@ if [[ -z "$MODE" ]]; then
154154 export MODE=stable
155155fi
156156
157+ # Set default value for $WORKFLOW
158+ if [[ -z " $WORKFLOW " ]]; then
159+ export WORKFLOW=pre-training
160+ fi
161+
157162# Unset optional variables if set to NONE
158163unset_optional_vars () {
159164 local optional_vars=(" JAX_VERSION" " LIBTPU_VERSION" " LIBTPU_GCS_PATH" )
@@ -185,6 +190,19 @@ install_custom_libtpu() {
185190 gsutil cp " $LIBTPU_GCS_PATH " " $libtpu_path "
186191}
187192
193+ install_maxtext_package_without_deps () {
194+ # The MaxText package is installed separately from its dependencies to optimize
195+ # docker image rebuild times by leveraging docker's layer caching.
196+ # Dependencies are installed in a separate step before MaxText code is
197+ # copied. This means that if MaxText code changes, but the
198+ # dependencies do not, docker can reuse the cached dependency layer, leading
199+ # to significantly faster image builds.
200+ if [ -f ' pyproject.toml' ]; then
201+ echo " Installing MaxText package without installing the dependencies (already installed)"
202+ python3 -m uv pip install --no-deps -e .
203+ fi
204+ }
205+
188206install_maxtext_with_deps () {
189207 if [[ " $DEVICE " != " tpu" && " $DEVICE " != " gpu" ]]; then
190208 echo -e " \n\nError: DEVICE must be either 'tpu' or 'gpu'.\n\n"
@@ -200,18 +218,31 @@ install_maxtext_with_deps() {
200218 python3 -m uv pip install --resolution=lowest -r " $dep_name " \
201219 -r ' src/install_maxtext_extra_deps/extra_deps_from_github.txt'
202220
203- # The MaxText package is installed separately from its dependencies to optimize
204- # docker image rebuild times by leveraging docker's layer caching.
205- # Dependencies are installed in a separate step before MaxText code is
206- # copied. This means that if MaxText code changes, but the
207- # dependencies do not, docker can reuse the cached dependency layer, leading
208- # to significantly faster image builds.
209- if [ -f ' pyproject.toml' ]; then
210- echo " Installing MaxText package without installing the dependencies (already installed)"
211- python3 -m uv pip install --no-deps -e .
221+ install_maxtext_package_without_deps
222+ }
223+
224+ install_post_training_deps () {
225+ if [[ " $DEVICE " != " tpu" ]]; then
226+ echo -e " \n\nError: DEVICE must be 'tpu'.\n\n"
227+ exit 1
212228 fi
229+ echo " Setting up MaxText post-training workflow for $DEVICE device"
230+ dep_name=' dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt'
231+ echo " Installing requirements from $dep_name "
232+ python3 -m uv pip install --resolution=lowest -r " $dep_name "
233+ python3 -m src.install_maxtext_extra_deps.install_post_train_extra_deps
213234}
214235
236+ # ---------- Post-Training workflow installation ----------
237+
238+ if [[ " $WORKFLOW " == " post-training" ]]; then
239+ install_post_training_deps
240+ install_maxtext_package_without_deps
241+ exit 0
242+ fi
243+
244+ # ---------- Pre-Training workflow installation ----------
245+
215246# stable mode installation
216247if [[ " $MODE " == " stable" ]]; then
217248 install_maxtext_with_deps
0 commit comments