Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions maxdiffusion_dependencies.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Use Python 3.10-slim-bullseye as the base image
FROM python:3.10-slim-bullseye
# Use Python 3.12-slim-bullseye as the base image
FROM python:3.12-slim-bullseye

# Environment variable for no-cache-dir and pip root user warning
ENV PIP_NO_CACHE_DIR=1
ENV PIP_ROOT_USER_ACTION=ignore

# Set environment variables for Google Cloud SDK and Python 3.10
ENV PYTHON_VERSION=3.10
ENV PYTHON_VERSION=3.12
ENV CLOUD_SDK_VERSION=latest

# Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.48.1
Expand Down
2 changes: 1 addition & 1 deletion requirements_with_jax_ai_image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.48.1
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flax.training import train_state
import orbax
import orbax.checkpoint as ocp
from orbax.checkpoint.logging import abstract_logger
from orbax.checkpoint.logging import AbstractLogger
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions

STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
Expand All @@ -43,7 +43,7 @@ def create_orbax_checkpoint_manager(
checkpoint_type: str,
dataset_type: str = "tf",
use_async: bool = True,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
orbax_logger: Optional[AbstractLogger] = None,
):
"""
Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.
Expand Down
Loading