22import glob
33import os
44from copy import deepcopy
5- from enum import Enum , unique
65from typing import Any , Dict , Optional , Tuple , Union
76
87import numpy as np
1514# ensure all plugins are loaded!
1615import dpdata .plugins
1716from dpdata .amber .mask import load_param_file , pick_by_amber_mask
17+ from dpdata .data_type import Axis , DataError , DataType , get_data_types
1818from dpdata .driver import Driver , Minimizer
1919from dpdata .format import Format
2020from dpdata .plugin import Plugin
@@ -33,112 +33,6 @@ def load_format(fmt):
3333 )
3434
3535
36- @unique
37- class Axis (Enum ):
38- """Data axis."""
39-
40- NFRAMES = "nframes"
41- NATOMS = "natoms"
42- NTYPES = "ntypes"
43- NBONDS = "nbonds"
44-
45-
46- class DataError (Exception ):
47- """Data is not correct."""
48-
49-
50- class DataType :
51- """DataType represents a type of data, like coordinates, energies, etc.
52-
53- Parameters
54- ----------
55- name : str
56- name of data
57- dtype : type or tuple[type]
58- data type, e.g. np.ndarray
59- shape : tuple[int], optional
60- shape of data. Used when data is list or np.ndarray. Use Axis to
61- represents numbers
62- required : bool, default=True
63- whether this data is required
64- """
65-
66- def __init__ (
67- self ,
68- name : str ,
69- dtype : type ,
70- shape : Tuple [int , Axis ] = None ,
71- required : bool = True ,
72- ) -> None :
73- self .name = name
74- self .dtype = dtype
75- self .shape = shape
76- self .required = required
77-
78- def real_shape (self , system : "System" ) -> Tuple [int ]:
79- """Returns expected real shape of a system."""
80- shape = []
81- for ii in self .shape :
82- if ii is Axis .NFRAMES :
83- shape .append (system .get_nframes ())
84- elif ii is Axis .NTYPES :
85- shape .append (system .get_ntypes ())
86- elif ii is Axis .NATOMS :
87- shape .append (system .get_natoms ())
88- elif ii is Axis .NBONDS :
89- # BondOrderSystem
90- shape .append (system .get_nbonds ())
91- elif isinstance (ii , int ):
92- shape .append (ii )
93- else :
94- raise RuntimeError ("Shape is not an int!" )
95- return tuple (shape )
96-
97- def check (self , system : "System" ):
98- """Check if a system has correct data of this type.
99-
100- Parameters
101- ----------
102- system : System
103- checked system
104-
105- Raises
106- ------
107- DataError
108- type or shape of data is not correct
109- """
110- # check if exists
111- if self .name in system .data :
112- data = system .data [self .name ]
113- # check dtype
114- # allow list for empty np.ndarray
115- if isinstance (data , list ) and not len (data ):
116- pass
117- elif not isinstance (data , self .dtype ):
118- raise DataError (
119- f"Type of { self .name } is { type (data ).__name__ } , but expected { self .dtype .__name__ } "
120- )
121- # check shape
122- if self .shape is not None :
123- shape = self .real_shape (system )
124- # skip checking empty list of np.ndarray
125- if isinstance (data , np .ndarray ):
126- if data .size and shape != data .shape :
127- raise DataError (
128- f"Shape of { self .name } is { data .shape } , but expected { shape } "
129- )
130- elif isinstance (data , list ):
131- if len (shape ) and shape [0 ] != len (data ):
132- raise DataError (
133- "Length of %s is %d, but expected %d"
134- % (self .name , len (data ), shape [0 ])
135- )
136- else :
137- raise RuntimeError ("Unsupported type to check shape" )
138- elif self .required :
139- raise DataError ("%s not found in data" % self .name )
140-
141-
14236class System (MSONable ):
14337 """The data System.
14438
@@ -1657,7 +1551,8 @@ def get_cls_name(cls: object) -> str:
16571551
16581552
16591553def add_format_methods ():
1660- """Add format methods to System, LabeledSystem, and MultiSystems.
1554+ """Add format methods to System, LabeledSystem, and MultiSystems; add data types
1555+ to System and LabeledSystem.
16611556
16621557 Notes
16631558 -----
@@ -1701,5 +1596,10 @@ def to_format(self, *args, **kwargs):
17011596 setattr (LabeledSystem , method , get_func (formatcls ))
17021597 setattr (MultiSystems , method , get_func (formatcls ))
17031598
1599+ # at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized
1600+ System .DTYPES = System .DTYPES + get_data_types (labeled = False )
1601+ LabeledSystem .DTYPES = LabeledSystem .DTYPES + get_data_types (labeled = False )
1602+ LabeledSystem .DTYPES = LabeledSystem .DTYPES + get_data_types (labeled = True )
1603+
17041604
17051605add_format_methods ()
0 commit comments