Skip to content

Commit c5313de

Browse files
committed
Conversion of dataclass objects and others, not raise error
1 parent fd53cde commit c5313de

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/maxdiffusion/configuration_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import Any, Dict, Tuple, Union
2727
from . import max_logging
2828
import numpy as np
29+
from dataclasses import asdict, is_dataclass
2930

3031
from huggingface_hub import create_repo, hf_hub_download
3132
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
@@ -54,16 +55,17 @@ class CustomEncoder(json.JSONEncoder):
5455
"""
5556

5657
def default(self, o):
57-
# This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
5858
if isinstance(o, type(jnp.dtype("bfloat16"))):
5959
return str(o)
60-
# Add fallbacks for other numpy types if needed
6160
if isinstance(o, np.integer):
6261
return int(o)
6362
if isinstance(o, np.floating):
6463
return float(o)
65-
# Let the base class default method raise the TypeError for other types
66-
return super().default(o)
64+
if is_dataclass(o):
65+
return asdict(o)
66+
else:
67+
max_logging.log(f"Warning: {o} of type {type(o)} is not JSON serializable")
68+
return None
6769

6870

6971
class FrozenDict(OrderedDict):

0 commit comments

Comments
 (0)