diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 728d2f2e3..f65273620 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -35,7 +35,11 @@ jobs: name: "TPU test (${{ matrix.tpu-type }})" runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' - name: Install dependencies run: | pip install -e . diff --git a/maxdiffusion_dependencies.Dockerfile b/maxdiffusion_dependencies.Dockerfile index c12707c39..5e792eb82 100644 --- a/maxdiffusion_dependencies.Dockerfile +++ b/maxdiffusion_dependencies.Dockerfile @@ -1,12 +1,12 @@ -# Use Python 3.10-slim-bullseye as the base image -FROM python:3.10-slim-bullseye +# Use Python 3.12-slim-bullseye as the base image +FROM python:3.12-slim-bullseye # Environment variable for no-cache-dir and pip root user warning ENV PIP_NO_CACHE_DIR=1 ENV PIP_ROOT_USER_ACTION=ignore # Set environment variables for Google Cloud SDK and Python 3.10 -ENV PYTHON_VERSION=3.10 +ENV PYTHON_VERSION=3.12 ENV CLOUD_SDK_VERSION=latest # Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors diff --git a/requirements.txt b/requirements.txt index 879b62d54..754cb2278 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 opencv-python-headless==4.10.0.84 -orbax-checkpoint==0.10.3 +orbax-checkpoint tokenizers==0.21.0 huggingface_hub>=0.30.2 transformers==4.48.1 diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 140938749..ea42e1cd4 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -26,7 +26,7 @@ tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 opencv-python-headless==4.10.0.84 -orbax-checkpoint==0.10.3 +orbax-checkpoint tokenizers==0.21.0 huggingface_hub>=0.30.2 transformers==4.48.1 diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6c..77cb6718e 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -28,7 +28,7 @@ from flax.training import train_state import orbax import orbax.checkpoint as ocp -from orbax.checkpoint.logging import abstract_logger +from orbax.checkpoint.logging import AbstractLogger from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT" @@ -43,7 +43,7 @@ def create_orbax_checkpoint_manager( checkpoint_type: str, dataset_type: str = "tf", use_async: bool = True, - orbax_logger: Optional[abstract_logger.AbstractLogger] = None, + orbax_logger: Optional[AbstractLogger] = None, ): """ Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.