@@ -21,28 +21,28 @@ ENV MODE=$MODE
2121RUN echo "Installing Post-Training dependencies (tunix, vLLM, tpu-inference) with MODE=${MODE}"
2222RUN pip uninstall -y jax jaxlib libtpu
2323
24- RUN pip install aiohttp==3.12.15
24+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv pip install aiohttp==3.12.15
2525
2626# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
27- RUN pip install keyring keyrings.google-artifactregistry-auth
27+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv pip install keyring keyrings.google-artifactregistry-auth
2828
29- RUN pip install numba==0.61.2
29+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv pip install numba==0.61.2
3030
3131COPY tunix /tunix
3232RUN pip uninstall -y google-tunix
33- RUN pip install -e /tunix --no-cache-dir
33+ RUN --mount=type=cache,target=/root/.cache/ pip --mount=type=cache,target=/root/.cache/uv pip install -e /tunix
3434
3535COPY vllm /vllm
36- RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir
36+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv VLLM_TARGET_DEVICE="tpu" pip install -e /vllm
3737
3838COPY tpu-inference /tpu-inference
39- RUN pip install -e /tpu-inference --no-cache-dir
39+ RUN --mount=type=cache,target=/root/.cache/ pip --mount=type=cache,target=/root/.cache/uv pip install -e /tpu-inference
4040
41- RUN pip install --no-deps qwix==0.1.4
41+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv pip install --no-deps qwix==0.1.4
4242
43- RUN pip install math-verify==0.9.0
43+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv pip install math-verify==0.9.0
4444
45- RUN if [ "$MODE" = "post-training-experimental" ]; then \
45+ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/uv if [ "$MODE" = "post-training-experimental" ]; then \
4646 echo "MODE=post-training-experimental: Re-installing JAX/libtpu" ; \
4747 pip uninstall -y jax jaxlib libtpu && \
4848 pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
0 commit comments