Skip to content

Commit 4745a70

Browse files
committed
Add test for config conversion for checkpointing
1 parent c5313de commit 4745a70

2 files changed

Lines changed: 53 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,16 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ri
6060
flash_min_seq_length: 4096
6161
dropout: 0.1
6262

63-
flash_block_sizes: {}
63+
flash_block_sizes: {
64+
"block_q" : 1024,
65+
"block_kv_compute" : 256,
66+
"block_kv" : 1024,
67+
"block_q_dkv" : 1024,
68+
"block_kv_dkv" : 1024,
69+
"block_kv_dkv_compute" : 256,
70+
"block_q_dq" : 1024,
71+
"block_kv_dq" : 1024
72+
}
6473
# Use on v6e
6574
# flash_block_sizes: {
6675
# "block_q" : 3024,
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
import os
3+
import yaml
4+
5+
from .. import pyconfig
6+
from ..configuration_utils import ConfigMixin
7+
from .. import __version__
8+
9+
class DummyConfigMixin(ConfigMixin):
  """Bare-bones ``ConfigMixin`` subclass used only by the test in this file.

  It carries no behavior of its own: every keyword argument given to the
  constructor is handed straight to ``register_to_config`` so the instance
  can later be serialized with ``to_json_string``.
  """

  # File name associated with this config class.
  config_name = "config.json"

  def __init__(self, **kwargs):
    """Register all supplied keyword arguments as config entries."""
    self.register_to_config(**kwargs)
14+
15+
def test_to_json_string_with_config():
  """Round-trip the base WAN 14B YAML config through ``to_json_string``.

  The test initializes ``pyconfig`` from ``configs/base_wan_14b.yml``, feeds
  the resulting keys into ``DummyConfigMixin``, serializes that object to a
  JSON string, and checks the decoded result:

  * the ``_class_name`` and ``_diffusers_version`` metadata entries,
  * a handful of values that must match the loaded config, and
  * that ``weights_dtype`` and ``precision`` are absent (the test expects
    ``to_json_string`` to strip them).
  """
  # Resolve the YAML file relative to this test module.
  here = os.path.dirname(__file__)
  yaml_path = os.path.join(here, "..", "configs", "base_wan_14b.yml")

  # pyconfig takes an argv-style list; the first slot is a placeholder.
  pyconfig.initialize([None, yaml_path])
  cfg = pyconfig.config

  # Build the dummy config object and serialize it in one step.
  serialized = DummyConfigMixin(**cfg.get_keys()).to_json_string()

  # Decode the JSON so individual entries can be inspected.
  decoded = json.loads(serialized)

  # Metadata entries written alongside the config values.
  assert decoded["_class_name"] == "DummyConfigMixin"
  assert decoded["_diffusers_version"] == __version__

  # Spot-check a few values against the loaded config, including a nested one.
  assert decoded["run_name"] == cfg.run_name
  assert decoded["pretrained_model_name_or_path"] == cfg.pretrained_model_name_or_path
  assert decoded["flash_block_sizes"]["block_q"] == cfg.flash_block_sizes["block_q"]

  # These keys are expected to be removed by to_json_string.
  for removed_key in ("weights_dtype", "precision"):
    assert removed_key not in decoded

0 commit comments

Comments
 (0)