Skip to content

Commit 8326582

Browse files
Unpinned orbax and updated orbax logger API
1 parent 4e0999c commit 8326582

3 files changed

Lines changed: 4 additions & 4 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ tensorflow>=2.17.0
2424
tensorflow-datasets>=4.9.6
2525
ruff>=0.1.5,<=0.2
2626
opencv-python-headless==4.10.0.84
27-
orbax-checkpoint==0.10.3
27+
orbax-checkpoint
2828
tokenizers==0.21.0
2929
huggingface_hub>=0.30.2
3030
transformers==4.48.1

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ tensorflow>=2.17.0
2626
tensorflow-datasets>=4.9.6
2727
ruff>=0.1.5,<=0.2
2828
opencv-python-headless==4.10.0.84
29-
orbax-checkpoint==0.10.3
29+
orbax-checkpoint
3030
tokenizers==0.21.0
3131
huggingface_hub>=0.30.2
3232
transformers==4.48.1

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from flax.training import train_state
2929
import orbax
3030
import orbax.checkpoint as ocp
31-
from orbax.checkpoint.logging import abstract_logger
31+
from orbax.checkpoint.logging import AbstractLogger
3232
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions
3333

3434
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
@@ -43,7 +43,7 @@ def create_orbax_checkpoint_manager(
4343
checkpoint_type: str,
4444
dataset_type: str = "tf",
4545
use_async: bool = True,
46-
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
46+
orbax_logger: Optional[AbstractLogger] = None,
4747
):
4848
"""
4949
Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.

0 commit comments

Comments
 (0)