Skip to content

Commit eb96e9e

Browse files
Update nightly to install flax at head (#296)
pin flax install Install only if nightly Syntax error in dockerfile fixed Update baseimage arg to build time as well
1 parent 5cbf844 commit eb96e9e

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

maxdiffusion_jax_ai_image_tpu.Dockerfile

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ ARG JAX_AI_IMAGE_BASEIMAGE
33
# JAX AI Base Image
44
FROM $JAX_AI_IMAGE_BASEIMAGE
55

6+
ARG JAX_AI_IMAGE_BASEIMAGE
7+
68
ARG COMMIT_HASH
79

810
ENV COMMIT_HASH=$COMMIT_HASH
@@ -18,5 +20,12 @@ COPY . .
1820
# Install Maxdiffusion Jax AI Image requirements
1921
RUN pip install -r /deps/requirements_with_jax_ai_image.txt
2022

23+
# TODO: Remove the flax pin and fsspec overrides once flax stable version releases
24+
RUN if echo "$JAX_AI_IMAGE_BASEIMAGE" | grep -q "nightly"; then \
25+
echo "Nightly build detected: Installing specific Flax commit and fsspec." && \
26+
pip install --upgrade --force-reinstall git+https://github.com/google/flax.git@ef78d6584623511746be4824965cdef42b464583 && \
27+
pip install "fsspec==2025.10.0"; \
28+
fi
29+
2130
# Run the script available in JAX-AI-Image base image to generate the manifest file
2231
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH

0 commit comments

Comments
 (0)