Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ jobs:
- uses: actions/checkout@v3
- name: Cleanup old docker images
run: docker system prune --all --force
- name: build maxdiffusion jax stable stack image
- name: build maxdiffusion jax ai image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
- name: build maxdiffusion jax nightly image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
Expand Down
10 changes: 6 additions & 4 deletions docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
# Each time you update the base image via a "bash docker_maxdiffusion_image_upload.sh", there will be a slow upload process
# (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds.

# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}}
# bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
Comment thread
Rohan-Bierneni marked this conversation as resolved.
# Note: The mode stable_stack is marked for deprecation, please use MODE=jax_ai_image instead
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
# bash docker_build_dependency_image.sh MODE=nightly
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
# bash docker_build_dependency_image.sh MODE=stable
Expand Down Expand Up @@ -69,17 +71,17 @@ if [[ ${DEVICE} == "gpu" ]]; then
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
if [[ "${MODE}" == "stable_stack" ]]; then
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
else
docker build --no-cache \
--network=host \
Expand Down
18 changes: 9 additions & 9 deletions docs/getting_started/run_maxdiffusion_via_xpk.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,23 @@ after which log out and log back in to the machine.
bash docker_build_dependency_image.sh
```

#### New: Build MaxDiffusion Docker Image with JAX Stable Stack
Comment thread
Rohan-Bierneni marked this conversation as resolved.
We're excited to announce that you can build the MaxDiffusion Docker image using the JAX Stable Stack base image. This provides a more reliable and consistent build environment.
#### New: Build MaxDiffusion Docker Image with JAX AI Images (Formerly known as JAX Stable Stack)
We're excited to announce that you can build the MaxDiffusion Docker image using the JAX AI base image. This provides a more reliable and consistent build environment.

###### What is JAX Stable Stack?
JAX Stable Stack provides a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions.
###### What is JAX AI Images?
JAX AI Images provide a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions.

###### How to Use It
To build the MaxDiffusion Docker image with JAX Stable Stack, simply set the MODE to `stable_stack` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:
To build the MaxDiffusion Docker image with JAX AI Images, simply set the MODE to `jax_ai_image` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:

```
# Example bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_BASEIMAGE}}
# Example bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2
bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE_BASEIMAGE}}
```

You can find a list of available JAX Stable Stack base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu).
You can find a list of available JAX AI base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu).

**Important Note:** The JAX Stable Stack is currently in the experimental phase. We encourage you to try it out and provide feedback.
**Important Note:** JAX AI Images is currently in the experimental phase. We encourage you to try it out and provide feedback.

3. After building the dependency image `maxdiffusion_base_image`, xpk can handle updates to the working directory when running `xpk workload create` and using `--base-docker-image`.

Expand Down
22 changes: 22 additions & 0 deletions maxdiffusion_jax_ai_image_tpu.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
ARG JAX_AI_IMAGE_BASEIMAGE

# JAX AI Base Image
FROM $JAX_AI_IMAGE_BASEIMAGE

ARG COMMIT_HASH

ENV COMMIT_HASH=$COMMIT_HASH

RUN mkdir -p /deps

# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .

# Install Maxdiffusion Jax AI Image requirements
RUN pip install -r /deps/requirements_with_jax_ai_image.txt

# Run the script available in JAX-AI-Image base image to generate the manifest file
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
22 changes: 0 additions & 22 deletions maxdiffusion_jax_stable_stack_tpu.Dockerfile

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Requirements for Building the MaxDifussion Docker Image
# These requirements are additional to the dependencies present in the JAX Stable Stack base image.
# These requirements are additional to the dependencies present in the JAX AI base image.
absl-py
datasets
einops==0.8.0
Expand Down
Loading