Skip to content

Commit 2790ed2

Browse files
Merge pull request #3300 from AI-Hypercomputer:post_train_docker
PiperOrigin-RevId: 878136407
2 parents 3ce981f + 7f052ca commit 2790ed2

4 files changed

Lines changed: 64 additions & 76 deletions

File tree

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 0 additions & 44 deletions
This file was deleted.

dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.12:${PATH}"
2323
ARG MODE
2424
ENV ENV_MODE=$MODE
2525

26+
ARG WORKFLOW
27+
ENV ENV_WORKFLOW=$WORKFLOW
28+
2629
ARG JAX_VERSION
2730
ENV ENV_JAX_VERSION=$JAX_VERSION
2831

@@ -43,14 +46,15 @@ WORKDIR /deps
4346
# Copy setup files and dependency files separately for better caching
4447
COPY tools/setup tools/setup/
4548
COPY dependencies/requirements/ dependencies/requirements/
46-
COPY src/install_maxtext_extra_deps/extra_deps_from_github.txt src/install_maxtext_extra_deps/
49+
COPY src/install_maxtext_extra_deps/ src/install_maxtext_extra_deps/
50+
COPY src/MaxText/integration/vllm/ src/MaxText/integration/vllm/
4751

4852
# Copy the custom libtpu.so file if it exists inside the maxtext repository
4953
COPY libtpu.so* /root/custom_libtpu/
5054

5155
# Install dependencies - these steps are cached unless the copied files change
52-
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_VERSION=$ENV_LIBTPU_VERSION DEVICE=${ENV_DEVICE}"
53-
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_VERSION=${ENV_LIBTPU_VERSION} DEVICE=${ENV_DEVICE}
56+
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE WORKFLOW=$ENV_WORKFLOW JAX_VERSION=$ENV_JAX_VERSION LIBTPU_VERSION=$ENV_LIBTPU_VERSION DEVICE=${ENV_DEVICE}"
57+
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} WORKFLOW=${ENV_WORKFLOW} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_VERSION=${ENV_LIBTPU_VERSION} DEVICE=${ENV_DEVICE}
5458

5559
# Now copy the remaining code (source files that may change frequently)
5660
COPY . .

dependencies/scripts/docker_build_dependency_image.sh

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,25 +118,20 @@ run_docker_build() {
118118
docker build --network host $(printf -- '--build-arg %q ' "$@") -f "$dockerfile_path" -t "$LOCAL_IMAGE_NAME" .
119119
}
120120

121-
# Function to build post-training image
122-
build_post_training_image() {
123-
DOCKERFILE_NAME=""
124-
if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then
125-
# To install vllm, tunix, tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
126-
# This assumes vllm, tunix, tpu-inference is a sibling directory to the current one (maxtext).
127-
rsync -a --exclude='__pycache__' ../tpu-inference .
128-
rsync -a --exclude='__pycache__' ../vllm .
129-
rsync -a --exclude='__pycache__' ../tunix .
130-
131-
# The cleanup is set to run even if the build fails to remove the copied directory.
132-
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM
133-
134-
DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
135-
echo "Building local post-training dependencies: $DOCKERFILE_NAME"
136-
else
137-
DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
138-
echo "Building remote post-training dependencies: $DOCKERFILE_NAME"
139-
fi
121+
# Function to build post-training dependencies from local GitHub head
122+
build_post_training_deps_from_local_github() {
123+
# To install vllm, tunix, and tpu-inference from local paths, we copy them into the build context, excluding __pycache__.
124+
# This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
125+
rsync -a --exclude='__pycache__' ../tpu-inference .
126+
rsync -a --exclude='__pycache__' ../vllm .
127+
rsync -a --exclude='__pycache__' ../tunix .
128+
129+
# The cleanup trap removes the copied directories on exit, even if the build fails.
130+
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM
131+
132+
DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
133+
echo "Building local post-training dependencies: $DOCKERFILE_NAME"
134+
140135
run_docker_build "$MAXTEXT_REPO_ROOT/dependencies/dockerfiles/$DOCKERFILE_NAME" \
141136
"MODE=${WORKFLOW}" "BASEIMAGE=${LOCAL_IMAGE_NAME}"
142137
}
@@ -170,7 +165,9 @@ build_tpu_image() {
170165

171166
# Handle post-training workflow if specified
172167
if [[ ${WORKFLOW} == "post-training" || ${WORKFLOW} == "post-training-experimental" ]]; then
173-
build_post_training_image
168+
if [[ ${POST_TRAINING_SOURCE} == "local" ]]; then
169+
build_post_training_deps_from_local_github
170+
fi
174171
fi
175172
}
176173

tools/setup/setup.sh

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ if [[ -z "$MODE" ]]; then
154154
export MODE=stable
155155
fi
156156

157+
# Set default value for $WORKFLOW
158+
if [[ -z "$WORKFLOW" ]]; then
159+
export WORKFLOW=pre-training
160+
fi
161+
157162
# Unset optional variables if set to NONE
158163
unset_optional_vars() {
159164
local optional_vars=("JAX_VERSION" "LIBTPU_VERSION" "LIBTPU_GCS_PATH")
@@ -185,6 +190,19 @@ install_custom_libtpu() {
185190
gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path"
186191
}
187192

193+
install_maxtext_package_without_deps() {
194+
# The MaxText package is installed separately from its dependencies to optimize
195+
# docker image rebuild times by leveraging docker's layer caching.
196+
# Dependencies are installed in a separate step before MaxText code is
197+
# copied. This means that if MaxText code changes, but the
198+
# dependencies do not, docker can reuse the cached dependency layer, leading
199+
# to significantly faster image builds.
200+
if [ -f 'pyproject.toml' ]; then
201+
echo "Installing MaxText package without installing the dependencies (already installed)"
202+
python3 -m uv pip install --no-deps -e .
203+
fi
204+
}
205+
188206
install_maxtext_with_deps() {
189207
if [[ "$DEVICE" != "tpu" && "$DEVICE" != "gpu" ]]; then
190208
echo -e "\n\nError: DEVICE must be either 'tpu' or 'gpu'.\n\n"
@@ -200,18 +218,31 @@ install_maxtext_with_deps() {
200218
python3 -m uv pip install --resolution=lowest -r "$dep_name" \
201219
-r 'src/install_maxtext_extra_deps/extra_deps_from_github.txt'
202220

203-
# The MaxText package is installed separately from its dependencies to optimize
204-
# docker image rebuild times by leveraging docker's layer caching.
205-
# Dependencies are installed in a separate step before MaxText code is
206-
# copied. This means that if MaxText code changes, but the
207-
# dependencies do not, docker can reuse the cached dependency layer, leading
208-
# to significantly faster image builds.
209-
if [ -f 'pyproject.toml' ]; then
210-
echo "Installing MaxText package without installing the dependencies (already installed)"
211-
python3 -m uv pip install --no-deps -e .
221+
install_maxtext_package_without_deps
222+
}
223+
224+
install_post_training_deps() {
225+
if [[ "$DEVICE" != "tpu" ]]; then
226+
echo -e "\n\nError: DEVICE must be 'tpu'.\n\n"
227+
exit 1
212228
fi
229+
echo "Setting up MaxText post-training workflow for $DEVICE device"
230+
dep_name='dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt'
231+
echo "Installing requirements from $dep_name"
232+
python3 -m uv pip install --resolution=lowest -r "$dep_name"
233+
python3 -m src.install_maxtext_extra_deps.install_post_train_extra_deps
213234
}
214235

236+
# ---------- Post-Training workflow installation ----------
237+
238+
if [[ "$WORKFLOW" == "post-training" ]]; then
239+
install_post_training_deps
240+
install_maxtext_package_without_deps
241+
exit 0
242+
fi
243+
244+
# ---------- Pre-Training workflow installation ----------
245+
215246
# stable mode installation
216247
if [[ "$MODE" == "stable" ]]; then
217248
install_maxtext_with_deps

0 commit comments

Comments
 (0)