Skip to content

Commit b61144f

Browse files
add a public method to register new DataType in a plugin (#505)
Move DataType to an independent module, so registering it will be easier. No breaking changes are made. --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7531150 commit b61144f

6 files changed

Lines changed: 182 additions & 117 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Install rdkit
2323
run: python -m pip install rdkit openbabel-wheel
2424
- name: Install dependencies
25-
run: python -m pip install .[amber,ase,pymatgen] coverage
25+
run: python -m pip install .[amber,ase,pymatgen] coverage ./tests/plugin
2626
- name: Test
2727
run: cd tests && coverage run --source=../dpdata -m unittest && cd .. && coverage combine tests/.coverage && coverage report
2828
- name: Run codecov

dpdata/data_type.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from enum import Enum, unique
2+
from typing import TYPE_CHECKING, Tuple
3+
4+
import numpy as np
5+
6+
from dpdata.plugin import Plugin
7+
8+
if TYPE_CHECKING:
9+
from dpdata.system import System
10+
11+
12+
@unique
class Axis(Enum):
    """Symbolic axis lengths that may appear in a ``DataType`` shape.

    A member stands for a dimension whose size is read from the system being
    checked (see ``DataType.real_shape``) instead of being a fixed integer.
    """

    # resolved through ``system.get_nframes()``
    NFRAMES = "nframes"
    # resolved through ``system.get_natoms()``
    NATOMS = "natoms"
    # resolved through ``system.get_ntypes()``
    NTYPES = "ntypes"
    # resolved through ``system.get_nbonds()`` (BondOrderSystem)
    NBONDS = "nbonds"
20+
21+
22+
class DataError(Exception):
    """Raised when data in a system is missing or has a wrong type or shape."""
24+
25+
26+
class DataType:
    """DataType represents a type of data, like coordinates, energies, etc.

    Parameters
    ----------
    name : str
        name of data; also the key looked up in ``system.data``
    dtype : type or tuple[type]
        data type, e.g. np.ndarray. When a tuple is given, data matching any
        of the listed types passes the type check
    shape : tuple[int or Axis], optional
        shape of data. Used when data is list or np.ndarray. Use Axis
        members for dimensions whose size depends on the system. When None,
        the shape is not checked
    required : bool, default=True
        whether this data is required
    """

    def __init__(
        self,
        name: str,
        dtype: type,
        shape: Tuple[int, Axis] = None,
        required: bool = True,
    ) -> None:
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.required = required

    def _dtype_name(self) -> str:
        """Return a printable name of the expected type(s) for error messages."""
        # ``dtype`` may be a tuple of types (as accepted by ``isinstance``);
        # a tuple has no ``__name__``, so each member is named individually.
        if isinstance(self.dtype, tuple):
            return " or ".join(
                getattr(tp, "__name__", repr(tp)) for tp in self.dtype
            )
        return getattr(self.dtype, "__name__", repr(self.dtype))

    def real_shape(self, system: "System") -> Tuple[int, ...]:
        """Returns expected real shape of a system.

        Parameters
        ----------
        system : System
            system whose frame/atom/type/bond counts replace Axis members

        Returns
        -------
        tuple[int]
            ``self.shape`` with every Axis member replaced by a concrete size

        Raises
        ------
        RuntimeError
            an entry of ``self.shape`` is neither an Axis member nor an int
        """
        shape = []
        for ii in self.shape:
            if ii is Axis.NFRAMES:
                shape.append(system.get_nframes())
            elif ii is Axis.NTYPES:
                shape.append(system.get_ntypes())
            elif ii is Axis.NATOMS:
                shape.append(system.get_natoms())
            elif ii is Axis.NBONDS:
                # BondOrderSystem
                shape.append(system.get_nbonds())
            elif isinstance(ii, int):
                shape.append(ii)
            else:
                raise RuntimeError("Shape is not an int!")
        return tuple(shape)

    def check(self, system: "System"):
        """Check if a system has correct data of this type.

        Parameters
        ----------
        system : System
            checked system

        Raises
        ------
        DataError
            type or shape of data is not correct, or required data is missing
        RuntimeError
            a shape is declared but the data is neither np.ndarray nor list
        """
        # check if exists
        if self.name not in system.data:
            if self.required:
                raise DataError(f"{self.name} not found in data")
            return
        data = system.data[self.name]

        # check dtype
        # allow list for empty np.ndarray
        is_empty_list = isinstance(data, list) and not data
        if not is_empty_list and not isinstance(data, self.dtype):
            raise DataError(
                f"Type of {self.name} is {type(data).__name__}, but expected {self._dtype_name()}"
            )

        # check shape
        if self.shape is None:
            return
        shape = self.real_shape(system)
        if isinstance(data, np.ndarray):
            # an empty array (size 0) is accepted whatever its shape
            if data.size and shape != data.shape:
                raise DataError(
                    f"Shape of {self.name} is {data.shape}, but expected {shape}"
                )
        elif isinstance(data, list):
            # only the leading dimension of a list is compared
            if len(shape) and shape[0] != len(data):
                raise DataError(
                    f"Length of {self.name} is {len(data)}, but expected {shape[0]}"
                )
        else:
            raise RuntimeError("Unsupported type to check shape")
116+
117+
118+
# Registry of extra data types returned by ``get_data_types(labeled=False)``.
__system_data_type_plugin = Plugin()
# Registry of extra data types returned by ``get_data_types(labeled=True)``.
__labeled_system_data_type_plugin = Plugin()
120+
121+
122+
def register_data_type(data_type: DataType, labeled: bool):
    """Add a data type to one of the module-level registries.

    The data type is stored under its ``name`` attribute.

    Parameters
    ----------
    data_type : DataType
        data type to be registered
    labeled : bool
        True to register it for LabeledSystem, False for System
    """
    if labeled:
        registry = __labeled_system_data_type_plugin
    else:
        registry = __system_data_type_plugin
    # ``Plugin.register(key)`` hands back a decorator; applying it to the
    # DataType instance stores that instance under ``key``.
    store = registry.register(data_type.name)
    store(data_type)
134+
135+
136+
def get_data_types(labeled: bool):
    """Return every data type registered through ``register_data_type``.

    Parameters
    ----------
    labeled : bool
        True to read the LabeledSystem registry, False for the System one

    Returns
    -------
    tuple
        values of the selected registry's ``plugins`` mapping
    """
    if labeled:
        registry = __labeled_system_data_type_plugin
    else:
        registry = __system_data_type_plugin
    registered = registry.plugins.values()
    return tuple(registered)

dpdata/system.py

Lines changed: 8 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import glob
33
import os
44
from copy import deepcopy
5-
from enum import Enum, unique
65
from typing import Any, Dict, Optional, Tuple, Union
76

87
import numpy as np
@@ -15,6 +14,7 @@
1514
# ensure all plugins are loaded!
1615
import dpdata.plugins
1716
from dpdata.amber.mask import load_param_file, pick_by_amber_mask
17+
from dpdata.data_type import Axis, DataError, DataType, get_data_types
1818
from dpdata.driver import Driver, Minimizer
1919
from dpdata.format import Format
2020
from dpdata.plugin import Plugin
@@ -33,112 +33,6 @@ def load_format(fmt):
3333
)
3434

3535

36-
@unique
37-
class Axis(Enum):
38-
"""Data axis."""
39-
40-
NFRAMES = "nframes"
41-
NATOMS = "natoms"
42-
NTYPES = "ntypes"
43-
NBONDS = "nbonds"
44-
45-
46-
class DataError(Exception):
47-
"""Data is not correct."""
48-
49-
50-
class DataType:
51-
"""DataType represents a type of data, like coordinates, energies, etc.
52-
53-
Parameters
54-
----------
55-
name : str
56-
name of data
57-
dtype : type or tuple[type]
58-
data type, e.g. np.ndarray
59-
shape : tuple[int], optional
60-
shape of data. Used when data is list or np.ndarray. Use Axis to
61-
represents numbers
62-
required : bool, default=True
63-
whether this data is required
64-
"""
65-
66-
def __init__(
67-
self,
68-
name: str,
69-
dtype: type,
70-
shape: Tuple[int, Axis] = None,
71-
required: bool = True,
72-
) -> None:
73-
self.name = name
74-
self.dtype = dtype
75-
self.shape = shape
76-
self.required = required
77-
78-
def real_shape(self, system: "System") -> Tuple[int]:
79-
"""Returns expected real shape of a system."""
80-
shape = []
81-
for ii in self.shape:
82-
if ii is Axis.NFRAMES:
83-
shape.append(system.get_nframes())
84-
elif ii is Axis.NTYPES:
85-
shape.append(system.get_ntypes())
86-
elif ii is Axis.NATOMS:
87-
shape.append(system.get_natoms())
88-
elif ii is Axis.NBONDS:
89-
# BondOrderSystem
90-
shape.append(system.get_nbonds())
91-
elif isinstance(ii, int):
92-
shape.append(ii)
93-
else:
94-
raise RuntimeError("Shape is not an int!")
95-
return tuple(shape)
96-
97-
def check(self, system: "System"):
98-
"""Check if a system has correct data of this type.
99-
100-
Parameters
101-
----------
102-
system : System
103-
checked system
104-
105-
Raises
106-
------
107-
DataError
108-
type or shape of data is not correct
109-
"""
110-
# check if exists
111-
if self.name in system.data:
112-
data = system.data[self.name]
113-
# check dtype
114-
# allow list for empty np.ndarray
115-
if isinstance(data, list) and not len(data):
116-
pass
117-
elif not isinstance(data, self.dtype):
118-
raise DataError(
119-
f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}"
120-
)
121-
# check shape
122-
if self.shape is not None:
123-
shape = self.real_shape(system)
124-
# skip checking empty list of np.ndarray
125-
if isinstance(data, np.ndarray):
126-
if data.size and shape != data.shape:
127-
raise DataError(
128-
f"Shape of {self.name} is {data.shape}, but expected {shape}"
129-
)
130-
elif isinstance(data, list):
131-
if len(shape) and shape[0] != len(data):
132-
raise DataError(
133-
"Length of %s is %d, but expected %d"
134-
% (self.name, len(data), shape[0])
135-
)
136-
else:
137-
raise RuntimeError("Unsupported type to check shape")
138-
elif self.required:
139-
raise DataError("%s not found in data" % self.name)
140-
141-
14236
class System(MSONable):
14337
"""The data System.
14438
@@ -1657,7 +1551,8 @@ def get_cls_name(cls: object) -> str:
16571551

16581552

16591553
def add_format_methods():
1660-
"""Add format methods to System, LabeledSystem, and MultiSystems.
1554+
"""Add format methods to System, LabeledSystem, and MultiSystems; add data types
1555+
to System and LabeledSystem.
16611556
16621557
Notes
16631558
-----
@@ -1701,5 +1596,10 @@ def to_format(self, *args, **kwargs):
17011596
setattr(LabeledSystem, method, get_func(formatcls))
17021597
setattr(MultiSystems, method, get_func(formatcls))
17031598

1599+
# at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized
1600+
System.DTYPES = System.DTYPES + get_data_types(labeled=False)
1601+
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False)
1602+
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True)
1603+
17041604

17051605
add_format_methods()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import numpy as np
2+
3+
from dpdata.data_type import Axis, DataType, register_data_type
4+
5+
# Data type used by the plugin tests: an optional "foo" np.ndarray holding one
# (2, 4) block per frame, registered for LabeledSystem.
register_data_type(
    DataType(
        name="foo",
        dtype=np.ndarray,
        shape=(Axis.NFRAMES, 2, 4),
        required=False,
    ),
    labeled=True,
)

# Object named by the ``dpdata.plugins`` entry point in pyproject.toml
# ("dpdata_plugin_test:ep").
ep = None

tests/plugin/pyproject.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[build-system]
2+
requires = ["setuptools>=61"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "dpdata_plugin_test"
7+
version = "0.0.0"
8+
description = "A test for dpdata plugin"
9+
dependencies = [
10+
'numpy',
11+
'dpdata',
12+
]
13+
readme = "README.md"
14+
requires-python = ">=3.7"
15+
16+
[project.entry-points.'dpdata.plugins']
17+
random = "dpdata_plugin_test:ep"

tests/test_custom_data_type.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,15 @@
44
import numpy as np
55

66
import dpdata
7-
from dpdata.system import Axis, DataType
87

98

109
class TestDeepmdLoadDumpComp(unittest.TestCase):
1110
def setUp(self):
12-
self.backup = dpdata.system.LabeledSystem.DTYPES
13-
dpdata.system.LabeledSystem.DTYPES = dpdata.system.LabeledSystem.DTYPES + (
14-
DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False),
15-
)
1611
self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
1712
self.foo = np.ones((len(self.system), 2, 4))
1813
self.system.data["foo"] = self.foo
1914
self.system.check_data()
2015

21-
def tearDown(self) -> None:
22-
dpdata.system.LabeledSystem.DTYPES = self.backup
23-
2416
def test_to_deepmd_raw(self):
2517
self.system.to_deepmd_raw("data_foo")
2618
foo = np.loadtxt("data_foo/foo.raw")

0 commit comments

Comments
 (0)