From ac910b1c01daf0eb5f60e569c908c27350955750 Mon Sep 17 00:00:00 2001 From: susanbao Date: Mon, 6 Oct 2025 17:27:51 +0000 Subject: [PATCH 1/2] OOM issue for JAX 0.7.2 --- requirements.txt | 2 +- requirements_with_jax_ai_image.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6d5e29024..11e45c7df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2 +jax>=0.6.2,<=0.7.0 jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 955a5e76f..e8b3b0227 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -1,7 +1,7 @@ # Requirements for Building the MaxDifussion Docker Image # These requirements are additional to the dependencies present in the JAX AI base image. --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2 +jax>=0.6.2,<=0.7.0 jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 From 1fb8fe2a24bf39e2efd888d6474500cdedb7bc5a Mon Sep 17 00:00:00 2001 From: Sanbao Su Date: Mon, 6 Oct 2025 21:41:10 +0000 Subject: [PATCH 2/2] use the latest jax head --- requirements.txt | 2 +- requirements_with_jax_ai_image.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 11e45c7df..c545288b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2,<=0.7.0 +jax@git+https://github.com/jax-ml/jax.git jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0 diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index e8b3b0227..834f8a70a 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -1,7 +1,7 @@ # Requirements for Building the MaxDifussion Docker Image # These requirements are additional to the dependencies present in the JAX AI base image. --extra-index-url https://download.pytorch.org/whl/cpu -jax>=0.6.2,<=0.7.0 +jax@git+https://github.com/jax-ml/jax.git jaxlib>=0.4.30 grain google-cloud-storage>=2.17.0