Skip to content

Commit e556ca1

Browse files
Make Maxdiffusion compatile w/ JAX 0.7.0 and latest orbax version (#211)
* Unpinned orbax and updated orbax logger API * Upgraded from python 3.10 to 3.12 for base image * Use python 3.12 in unit tests
1 parent 4e0999c commit e556ca1

5 files changed

Lines changed: 12 additions & 8 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ jobs:
3535
name: "TPU test (${{ matrix.tpu-type }})"
3636
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
3737
steps:
38-
- uses: actions/checkout@v3
38+
- uses: actions/checkout@v4
39+
- name: Set up Python 3.12
40+
uses: actions/setup-python@v5
41+
with:
42+
python-version: '3.12'
3943
- name: Install dependencies
4044
run: |
4145
pip install -e .

maxdiffusion_dependencies.Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
# Use Python 3.10-slim-bullseye as the base image
2-
FROM python:3.10-slim-bullseye
1+
# Use Python 3.12-slim-bullseye as the base image
2+
FROM python:3.12-slim-bullseye
33

44
# Environment variable for no-cache-dir and pip root user warning
55
ENV PIP_NO_CACHE_DIR=1
66
ENV PIP_ROOT_USER_ACTION=ignore
77

88
# Set environment variables for Google Cloud SDK and Python 3.10
9-
ENV PYTHON_VERSION=3.10
9+
ENV PYTHON_VERSION=3.12
1010
ENV CLOUD_SDK_VERSION=latest
1111

1212
# Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ tensorflow>=2.17.0
2424
tensorflow-datasets>=4.9.6
2525
ruff>=0.1.5,<=0.2
2626
opencv-python-headless==4.10.0.84
27-
orbax-checkpoint==0.10.3
27+
orbax-checkpoint
2828
tokenizers==0.21.0
2929
huggingface_hub>=0.30.2
3030
transformers==4.48.1

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ tensorflow>=2.17.0
2626
tensorflow-datasets>=4.9.6
2727
ruff>=0.1.5,<=0.2
2828
opencv-python-headless==4.10.0.84
29-
orbax-checkpoint==0.10.3
29+
orbax-checkpoint
3030
tokenizers==0.21.0
3131
huggingface_hub>=0.30.2
3232
transformers==4.48.1

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from flax.training import train_state
2929
import orbax
3030
import orbax.checkpoint as ocp
31-
from orbax.checkpoint.logging import abstract_logger
31+
from orbax.checkpoint.logging import AbstractLogger
3232
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions
3333

3434
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
@@ -43,7 +43,7 @@ def create_orbax_checkpoint_manager(
4343
checkpoint_type: str,
4444
dataset_type: str = "tf",
4545
use_async: bool = True,
46-
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
46+
orbax_logger: Optional[AbstractLogger] = None,
4747
):
4848
"""
4949
Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.

0 commit comments

Comments
 (0)