diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index 331caeb3a..4af852141 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -29,9 +29,9 @@ jobs: - uses: actions/checkout@v3 - name: Cleanup old docker images run: docker system prune --all --force - - name: build maxdiffusion jax stable stack image + - name: build maxdiffusion jax ai image run: | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest + bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest - name: build maxdiffusion jax nightly image run: | bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index e52e1c78f..d8838ec24 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -20,7 +20,9 @@ # Each time you update the base image via a "bash docker_maxdiffusion_image_upload.sh", there will be a slow upload process # (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds. -# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}} +# bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}} +# Note: The mode stable_stack is marked for deprecation, please use MODE=jax_ai_image instead +# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}} # bash docker_build_dependency_image.sh MODE=nightly # bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13 # bash docker_build_dependency_image.sh MODE=stable @@ -69,17 +71,17 @@ if [[ ${DEVICE} == "gpu" ]]; then fi docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else - if [[ "${MODE}" == "stable_stack" ]]; then + if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then if [[ ! -v BASEIMAGE ]]; then echo "Erroring out because BASEIMAGE is unset, please set it!" exit 1 fi docker build --no-cache \ - --build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ + --build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \ --build-arg COMMIT_HASH=${COMMIT_HASH} \ --network=host \ -t ${LOCAL_IMAGE_NAME} \ - -f maxdiffusion_jax_stable_stack_tpu.Dockerfile . + -f maxdiffusion_jax_ai_image_tpu.Dockerfile . else docker build --no-cache \ --network=host \ diff --git a/docs/getting_started/run_maxdiffusion_via_xpk.md b/docs/getting_started/run_maxdiffusion_via_xpk.md index 87d43f23e..c2e62ffcb 100644 --- a/docs/getting_started/run_maxdiffusion_via_xpk.md +++ b/docs/getting_started/run_maxdiffusion_via_xpk.md @@ -62,23 +62,23 @@ after which log out and log back in to the machine. bash docker_build_dependency_image.sh ``` - #### New: Build MaxDiffusion Docker Image with JAX Stable Stack - We're excited to announce that you can build the MaxDiffusion Docker image using the JAX Stable Stack base image. This provides a more reliable and consistent build environment. + #### New: Build MaxDiffusion Docker Image with JAX AI Images (Formerly known as JAX Stable Stack) + We're excited to announce that you can build the MaxDiffusion Docker image using the JAX AI base image. This provides a more reliable and consistent build environment. - ###### What is JAX Stable Stack? - JAX Stable Stack provides a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions. + ###### What is JAX AI Images? + JAX AI Images provide a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions. ###### How to Use It - To build the MaxDiffusion Docker image with JAX Stable Stack, simply set the MODE to `stable_stack` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script: + To build the MaxDiffusion Docker image with JAX AI Images, simply set the MODE to `jax_ai_image` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script: ``` - # Example bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1 - bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_BASEIMAGE}} + # Example bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2 + bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE_BASEIMAGE}} ``` - You can find a list of available JAX Stable Stack base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu). + You can find a list of available JAX AI base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu). - **Important Note:** The JAX Stable Stack is currently in the experimental phase. We encourage you to try it out and provide feedback. + **Important Note:** JAX AI Images is currently in the experimental phase. We encourage you to try it out and provide feedback. 3. After building the dependency image `maxdiffusion_base_image`, xpk can handle updates to the working directory when running `xpk workload create` and using `--base-docker-image`. diff --git a/maxdiffusion_jax_ai_image_tpu.Dockerfile b/maxdiffusion_jax_ai_image_tpu.Dockerfile new file mode 100644 index 000000000..cab50feef --- /dev/null +++ b/maxdiffusion_jax_ai_image_tpu.Dockerfile @@ -0,0 +1,22 @@ +ARG JAX_AI_IMAGE_BASEIMAGE + +# JAX AI Base Image +FROM $JAX_AI_IMAGE_BASEIMAGE + +ARG COMMIT_HASH + +ENV COMMIT_HASH=$COMMIT_HASH + +RUN mkdir -p /deps + +# Set the working directory in the container +WORKDIR /deps + +# Copy all files from local workspace into docker container +COPY . . + +# Install Maxdiffusion Jax AI Image requirements +RUN pip install -r /deps/requirements_with_jax_ai_image.txt + +# Run the script available in JAX-AI-Image base image to generate the manifest file +RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file diff --git a/maxdiffusion_jax_stable_stack_tpu.Dockerfile b/maxdiffusion_jax_stable_stack_tpu.Dockerfile deleted file mode 100644 index 5b3602abd..000000000 --- a/maxdiffusion_jax_stable_stack_tpu.Dockerfile +++ /dev/null @@ -1,22 +0,0 @@ -ARG JAX_STABLE_STACK_BASEIMAGE - -# JAX Stable Stack Base Image -FROM $JAX_STABLE_STACK_BASEIMAGE - -ARG COMMIT_HASH - -ENV COMMIT_HASH=$COMMIT_HASH - -RUN mkdir -p /deps - -# Set the working directory in the container -WORKDIR /deps - -# Copy all files from local workspace into docker container -COPY . . - -# Install Maxdiffusion jax stable stack requirements -RUN pip install -r /deps/requirements_with_jax_stable_stack.txt - -# Run the script available in JAX-Stable-Stack base image to generate the manifest file -RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_ai_image.txt similarity index 95% rename from requirements_with_jax_stable_stack.txt rename to requirements_with_jax_ai_image.txt index 5a88c800f..760755ffe 100644 --- a/requirements_with_jax_stable_stack.txt +++ b/requirements_with_jax_ai_image.txt @@ -1,5 +1,5 @@ # Requirements for Building the MaxDifussion Docker Image -# These requirements are additional to the dependencies present in the JAX Stable Stack base image. +# These requirements are additional to the dependencies present in the JAX AI base image. absl-py datasets einops==0.8.0