Skip to content

Commit 47619b7

Browse files
support -1 shape in DataType (#510)
Sometimes the size of a dimension is not known in advance and can be any integer. This PR supports that case: a dimension in a `DataType` shape may be declared as `-1`. Note: only one `-1` can be used in a shape. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9ef8daf commit 47619b7

5 files changed

Lines changed: 58 additions & 2 deletions

File tree

dpdata/data_type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ class Axis(Enum):
1919
NBONDS = "nbonds"
2020

2121

22+
class AnyInt(int):
    """Integer placeholder that compares equal to any other value.

    It serves as a wildcard for a dimension whose size is unknown (declared
    as ``-1`` in a shape): a shape tuple holding an ``AnyInt`` matches any
    concrete size at that position, because tuple comparison relies on the
    element-wise ``==``.

    Notes
    -----
    ``__eq__`` is overridden without ``__hash__``, so instances are
    unhashable. No hash value could be consistent with "equal to
    everything" anyway.
    """

    def __eq__(self, other):
        """Return True regardless of ``other``."""
        return True

    def __ne__(self, other):
        """Return False regardless of ``other``.

        ``int`` provides its own ``__ne__``; without this override ``!=``
        would still compare numerically and could contradict ``==``.
        """
        return False
27+
28+
2229
class DataError(Exception):
2330
"""Data is not correct."""
2431

@@ -64,6 +71,8 @@ def real_shape(self, system: "System") -> Tuple[int]:
6471
elif ii is Axis.NBONDS:
6572
# BondOrderSystem
6673
shape.append(system.get_nbonds())
74+
elif ii == -1:
75+
shape.append(AnyInt(-1))
6776
elif isinstance(ii, int):
6877
shape.append(ii)
6978
else:

dpdata/deepmd/comp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def to_system_data(folder, type_map=None, labels=True):
8787
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format."
8888
)
8989
continue
90+
natoms = data["coords"].shape[1]
9091
shape = [
91-
-1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]
92+
natoms if xx == dpdata.system.Axis.NATOMS else xx
93+
for xx in dtype.shape[1:]
9294
]
9395
all_data = []
9496
for ii in sets:

dpdata/deepmd/raw.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ def to_system_data(folder, type_map=None, labels=True):
8686
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format."
8787
)
8888
continue
89+
natoms = data["coords"].shape[1]
8990
shape = [
90-
-1 if xx == dpdata.system.Axis.NATOMS else xx
91+
natoms if xx == dpdata.system.Axis.NATOMS else xx
9192
for xx in dtype.shape[1:]
9293
]
9394
if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")):

tests/plugin/dpdata_plugin_test/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,9 @@
88
DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True
99
)
1010

11+
register_data_type(
12+
DataType("bar", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, -1), required=False),
13+
labeled=True,
14+
)
15+
1116
ep = None

tests/test_custom_data_type.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,42 @@ def test_from_deepmd_hdf5(self):
4343
self.system.to_deepmd_hdf5("data_foo.h5")
4444
x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5")
4545
np.testing.assert_allclose(x.data["foo"], self.foo)
46+
47+
48+
class TestDeepmdLoadDumpCompAny(unittest.TestCase):
    """Round-trip the custom ``bar`` data through the deepmd raw/npy/hdf5 formats.

    ``bar`` is attached with shape ``(nframes, natoms, 2)``; each test dumps the
    system and checks the stored or re-loaded values against the original array.
    """

    def setUp(self):
        """Load the labeled system and attach an all-ones ``bar`` array."""
        system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
        self.system = system
        nframes = len(system)
        natoms = system.get_natoms()
        self.bar = np.ones((nframes, natoms, 2))
        system.data["bar"] = self.bar
        system.check_data()

    def test_to_deepmd_raw(self):
        """The dumped ``bar.raw`` file holds the original values."""
        self.system.to_deepmd_raw("data_bar")
        dumped = np.loadtxt("data_bar/bar.raw")
        expected_shape = self.bar.shape
        np.testing.assert_allclose(dumped.reshape(expected_shape), self.bar)

    def test_from_deepmd_raw(self):
        """A system re-loaded from deepmd/raw carries the same ``bar``."""
        self.system.to_deepmd_raw("data_bar")
        restored = dpdata.LabeledSystem("data_bar", fmt="deepmd/raw")
        np.testing.assert_allclose(restored.data["bar"], self.bar)

    def test_to_deepmd_npy(self):
        """The dumped ``bar.npy`` file holds the original values."""
        self.system.to_deepmd_npy("data_bar")
        dumped = np.load("data_bar/set.000/bar.npy")
        expected_shape = self.bar.shape
        np.testing.assert_allclose(dumped.reshape(expected_shape), self.bar)

    def test_from_deepmd_npy(self):
        """A system re-loaded from deepmd/npy carries the same ``bar``."""
        self.system.to_deepmd_npy("data_bar")
        restored = dpdata.LabeledSystem("data_bar", fmt="deepmd/npy")
        np.testing.assert_allclose(restored.data["bar"], self.bar)

    def test_to_deepmd_hdf5(self):
        """The ``bar`` dataset written to the HDF5 file holds the original values."""
        self.system.to_deepmd_hdf5("data_bar.h5")
        with h5py.File("data_bar.h5") as handle:
            # Slice with [:] to copy the dataset into memory before the file closes.
            dumped = handle["set.000/bar.npy"][:]
        expected_shape = self.bar.shape
        np.testing.assert_allclose(dumped.reshape(expected_shape), self.bar)

    def test_from_deepmd_hdf5(self):
        """A system re-loaded from deepmd/hdf5 carries the same ``bar``."""
        self.system.to_deepmd_hdf5("data_bar.h5")
        restored = dpdata.LabeledSystem("data_bar.h5", fmt="deepmd/hdf5")
        np.testing.assert_allclose(restored.data["bar"], self.bar)

0 commit comments

Comments
 (0)