Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.48.1
Expand Down
2 changes: 1 addition & 1 deletion requirements_with_jax_ai_image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
orbax-checkpoint
tokenizers==0.21.0
huggingface_hub>=0.30.2
transformers==4.48.1
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flax.training import train_state
import orbax
import orbax.checkpoint as ocp
from orbax.checkpoint.logging import abstract_logger
from orbax.checkpoint.logging import AbstractLogger
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions

STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
Expand All @@ -43,7 +43,7 @@ def create_orbax_checkpoint_manager(
checkpoint_type: str,
dataset_type: str = "tf",
use_async: bool = True,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
orbax_logger: Optional[AbstractLogger] = None,
):
"""
Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.
Expand Down
Loading