Skip to content

Commit 3267ec9

Browse files
Merge branch 'main' into wan_transformer
2 parents 82b719e + 4a8155e commit 3267ec9

21 files changed

Lines changed: 1564 additions & 47 deletions

.github/workflows/UploadDockerImages.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ jobs:
2929
- uses: actions/checkout@v3
3030
- name: Cleanup old docker images
3131
run: docker system prune --all --force
32-
- name: build maxdiffusion jax stable stack image
32+
- name: build maxdiffusion jax ai image
3333
run: |
34-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest
34+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
3535
- name: build maxdiffusion jax nightly image
3636
run: |
3737
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
@@ -44,7 +44,7 @@ jobs:
4444
run: docker system prune --all --force
4545
- name: build maxdiffusion jax stable stack gpu image
4646
run: |
47-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
47+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu
4848
- name: build maxdiffusion jax nightly image
4949
run: |
5050
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu

docker_build_dependency_image.sh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
# Each time you update the base image via a "bash docker_maxdiffusion_image_upload.sh", there will be a slow upload process
2121
# (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds.
2222

23-
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}}
23+
# bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
24+
# Note: The mode stable_stack is marked for deprecation, please use MODE=jax_ai_image instead
25+
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
2426
# bash docker_build_dependency_image.sh MODE=nightly
2527
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
2628
# bash docker_build_dependency_image.sh MODE=stable
@@ -69,17 +71,17 @@ if [[ ${DEVICE} == "gpu" ]]; then
6971
fi
7072
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
7173
else
72-
if [[ "${MODE}" == "stable_stack" ]]; then
74+
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
7375
if [[ ! -v BASEIMAGE ]]; then
7476
echo "Erroring out because BASEIMAGE is unset, please set it!"
7577
exit 1
7678
fi
7779
docker build --no-cache \
78-
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
80+
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
7981
--build-arg COMMIT_HASH=${COMMIT_HASH} \
8082
--network=host \
8183
-t ${LOCAL_IMAGE_NAME} \
82-
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
84+
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
8385
else
8486
docker build --no-cache \
8587
--network=host \

docs/getting_started/run_maxdiffusion_via_xpk.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,23 @@ after which log out and log back in to the machine.
6262
bash docker_build_dependency_image.sh
6363
```
6464

65-
#### New: Build MaxDiffusion Docker Image with JAX Stable Stack
66-
We're excited to announce that you can build the MaxDiffusion Docker image using the JAX Stable Stack base image. This provides a more reliable and consistent build environment.
65+
#### New: Build MaxDiffusion Docker Image with JAX AI Images (Formerly known as JAX Stable Stack)
66+
We're excited to announce that you can build the MaxDiffusion Docker image using the JAX AI base image. This provides a more reliable and consistent build environment.
6767
68-
###### What is JAX Stable Stack?
69-
JAX Stable Stack provides a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions.
68+
###### What is JAX AI Images?
69+
JAX AI Images provide a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions.
7070
7171
###### How to Use It
72-
To build the MaxDiffusion Docker image with JAX Stable Stack, simply set the MODE to `stable_stack` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:
72+
To build the MaxDiffusion Docker image with JAX AI Images, simply set the MODE to `jax_ai_image` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:
7373
7474
```
75-
# Example bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1
76-
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_BASEIMAGE}}
75+
# Example bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2
76+
bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE_BASEIMAGE}}
7777
```
7878
79-
You can find a list of available JAX Stable Stack base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu).
79+
You can find a list of available JAX AI base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu).
8080
81-
**Important Note:** The JAX Stable Stack is currently in the experimental phase. We encourage you to try it out and provide feedback.
81+
**Important Note:** JAX AI Images is currently in the experimental phase. We encourage you to try it out and provide feedback.
8282
8383
3. After building the dependency image `maxdiffusion_base_image`, xpk can handle updates to the working directory when running `xpk workload create` and using `--base-docker-image`.
8484
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
ARG JAX_AI_IMAGE_BASEIMAGE
2+
3+
# JAX AI Base Image
4+
FROM $JAX_AI_IMAGE_BASEIMAGE
5+
6+
ARG COMMIT_HASH
7+
8+
ENV COMMIT_HASH=$COMMIT_HASH
9+
10+
RUN mkdir -p /deps
11+
12+
# Set the working directory in the container
13+
WORKDIR /deps
14+
15+
# Copy all files from local workspace into docker container
16+
COPY . .
17+
18+
# Install Maxdiffusion Jax AI Image requirements
19+
RUN pip install -r /deps/requirements_with_jax_ai_image.txt
20+
21+
# Run the script available in JAX-AI-Image base image to generate the manifest file
22+
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH

maxdiffusion_jax_stable_stack_tpu.Dockerfile

Lines changed: 0 additions & 22 deletions
This file was deleted.

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,6 @@ huggingface_hub==0.30.2
3030
transformers==4.48.1
3131
einops==0.8.0
3232
sentencepiece
33-
aqtp
33+
aqtp
34+
imageio==2.37.0
35+
imageio-ffmpeg==0.6.0
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Requirements for Building the MaxDifussion Docker Image
2-
# These requirements are additional to the dependencies present in the JAX Stable Stack base image.
2+
# These requirements are additional to the dependencies present in the JAX AI base image.
33
absl-py
44
datasets
55
einops==0.8.0
@@ -31,4 +31,6 @@ tensorflow-datasets>=4.9.6
3131
tokenizers==0.21.0
3232
torch==2.5.1
3333
torchvision==0.20.1
34-
transformers==4.48.1
34+
transformers==4.48.1
35+
imageio==2.37.0
36+
imageio-ffmpeg==0.6.0

src/maxdiffusion/configs/base14.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ diffusion_scheduler_config: {
9292

9393
# Hardware
9494
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
95+
skip_jax_distributed_system: False
9596

9697
base_output_directory: ""
9798

src/maxdiffusion/configs/base21.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ diffusion_scheduler_config: {
9191

9292
# Hardware
9393
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
94+
skip_jax_distributed_system: False
9495

9596
# Output directory
9697
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ diffusion_scheduler_config: {
104104

105105
# Hardware
106106
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
107+
skip_jax_distributed_system: False
107108

108109
# Output directory
109110
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"

0 commit comments

Comments
 (0)