Skip to content

Commit 811a133

Browse files
committed
Add test for config conversion for checkpointing
1 parent c5313de commit 811a133

2 files changed

Lines changed: 52 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,16 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ri
6060
flash_min_seq_length: 4096
6161
dropout: 0.1
6262

63-
flash_block_sizes: {}
63+
flash_block_sizes: {
64+
"block_q" : 1024,
65+
"block_kv_compute" : 256,
66+
"block_kv" : 1024,
67+
"block_q_dkv" : 1024,
68+
"block_kv_dkv" : 1024,
69+
"block_kv_dkv_compute" : 256,
70+
"block_q_dq" : 1024,
71+
"block_kv_dq" : 1024
72+
}
6473
# Use on v6e
6574
# flash_block_sizes: {
6675
# "block_q" : 3024,
Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,42 @@
1+
import json
2+
import os
3+
4+
from .. import pyconfig
5+
from ..configuration_utils import ConfigMixin
6+
from .. import __version__
7+
8+
class DummyConfigMixin(ConfigMixin):
    """Minimal ``ConfigMixin`` subclass used as a fixture by the test below.

    Every keyword argument passed to the constructor is handed, unchanged,
    to ``register_to_config``.
    """

    # NOTE(review): presumably the file name ConfigMixin uses when a config
    # is saved or loaded -- confirm against ConfigMixin.
    config_name = "config.json"

    def __init__(self, **kwargs):
        # No filtering or validation here: all kwargs are registered as-is.
        self.register_to_config(**kwargs)
13+
14+
def test_to_json_string_with_config():
    """Check the JSON produced by ``to_json_string`` for the WAN 14B config.

    The test initializes ``pyconfig`` from ``configs/base_wan_14b.yml``,
    builds a ``DummyConfigMixin`` from the resulting keys, serializes it
    with ``to_json_string`` and inspects the parsed JSON.
    """
    # Path of the YAML file, resolved relative to this test module.
    yaml_path = os.path.join(
        os.path.dirname(__file__), "..", "configs", "base_wan_14b.yml"
    )

    # NOTE(review): the argument looks like an argv-style list whose first
    # slot (normally the program name) is unused -- confirm against pyconfig.
    pyconfig.initialize([None, yaml_path])
    cfg = pyconfig.config

    # Serialize a fixture built from every config key, then parse it back.
    fixture = DummyConfigMixin(**cfg.get_keys())
    payload = json.loads(fixture.to_json_string())

    # Metadata fields expected in the serialized output.
    assert payload["_class_name"] == "DummyConfigMixin"
    assert payload["_diffusers_version"] == __version__

    # Spot-check a few values against the loaded config.
    assert payload["run_name"] == cfg.run_name
    assert (
        payload["pretrained_model_name_or_path"]
        == cfg.pretrained_model_name_or_path
    )
    flash_sizes = payload["flash_block_sizes"]
    assert flash_sizes["block_q"] == cfg.flash_block_sizes["block_q"]

    # Per the original author's note, to_json_string explicitly removes
    # these keys, so they must be absent from the output.
    for dropped_key in ("weights_dtype", "precision"):
        assert dropped_key not in payload

0 commit comments

Comments
 (0)