@@ -12,13 +12,17 @@ class NormalGenerator:
1212 def __init__ (self ):
1313 self .randn_generator = self .get_randn_generator ()
1414 self .rand_generator = self .get_rand_generator ()
15+ self .choice_generator = self .get_choice_generator ()
1516
1617 def randn (self , number ):
1718 return next (self .randn_generator )
1819
1920 def rand (self , number ):
2021 return next (self .rand_generator )
2122
23+ def choice (self , total_natoms , pert_natoms , replace ):
24+ return next (self .choice_generator )[:pert_natoms ]
25+
2226 @staticmethod
2327 def get_randn_generator ():
2428 data = np .asarray (
@@ -44,18 +48,26 @@ def get_rand_generator():
4448 [0.23182233 , 0.87106847 , 0.68728511 , 0.94180274 , 0.92860453 , 0.69191187 ]
4549 )
4650
51+ @staticmethod
52+ def get_choice_generator ():
53+ yield np .asarray ([5 , 3 , 7 , 6 , 2 , 1 , 4 , 0 ])
54+
4755
4856class UniformGenerator :
4957 def __init__ (self ):
5058 self .randn_generator = self .get_randn_generator ()
5159 self .rand_generator = self .get_rand_generator ()
60+ self .choice_generator = self .get_choice_generator ()
5261
5362 def randn (self , number ):
5463 return next (self .randn_generator )
5564
5665 def rand (self , number ):
5766 return next (self .rand_generator )
5867
68+ def choice (self , total_natoms , pert_natoms , replace ):
69+ return next (self .choice_generator )
70+
5971 @staticmethod
6072 def get_randn_generator ():
6173 data = [
@@ -97,18 +109,26 @@ def get_rand_generator():
97109 yield np .asarray (data [count ])
98110 count += 1
99111
112+ @staticmethod
113+ def get_choice_generator ():
114+ yield np .asarray ([5 , 3 , 7 , 6 , 2 , 1 , 4 , 0 ])
115+
100116
101117class ConstGenerator :
102118 def __init__ (self ):
103119 self .randn_generator = self .get_randn_generator ()
104120 self .rand_generator = self .get_rand_generator ()
121+ self .choice_generator = self .get_choice_generator ()
105122
106123 def randn (self , number ):
107124 return next (self .randn_generator )
108125
109126 def rand (self , number ):
110127 return next (self .rand_generator )
111128
129+ def choice (self , total_natoms , pert_natoms , replace ):
130+ return next (self .choice_generator )
131+
112132 @staticmethod
113133 def get_randn_generator ():
114134 data = np .asarray (
@@ -135,13 +155,18 @@ def get_rand_generator():
135155 [0.01525907 , 0.68387374 , 0.39768541 , 0.55596047 , 0.26557088 , 0.60883073 ]
136156 )
137157
158+ @staticmethod
159+ def get_choice_generator ():
160+ yield np .asarray ([5 , 3 , 7 , 6 , 2 , 1 , 4 , 0 ])
161+
138162
139163# %%
140164class TestPerturbNormal (unittest .TestCase , CompSys , IsPBC ):
141165 @patch ("numpy.random" )
142166 def setUp (self , random_mock ):
143167 random_mock .rand = NormalGenerator ().rand
144168 random_mock .randn = NormalGenerator ().randn
169+ random_mock .choice = NormalGenerator ().choice
145170 system_1_origin = dpdata .System ("poscars/POSCAR.SiC" , fmt = "vasp/poscar" )
146171 self .system_1 = system_1_origin .perturb (1 , 0.05 , 0.6 , "normal" )
147172 self .system_2 = dpdata .System ("poscars/POSCAR.SiC.normal" , fmt = "vasp/poscar" )
@@ -153,6 +178,7 @@ class TestPerturbUniform(unittest.TestCase, CompSys, IsPBC):
153178 def setUp (self , random_mock ):
154179 random_mock .rand = UniformGenerator ().rand
155180 random_mock .randn = UniformGenerator ().randn
181+ random_mock .choice = UniformGenerator ().choice
156182 system_1_origin = dpdata .System ("poscars/POSCAR.SiC" , fmt = "vasp/poscar" )
157183 self .system_1 = system_1_origin .perturb (1 , 0.05 , 0.6 , "uniform" )
158184 self .system_2 = dpdata .System ("poscars/POSCAR.SiC.uniform" , fmt = "vasp/poscar" )
@@ -164,11 +190,24 @@ class TestPerturbConst(unittest.TestCase, CompSys, IsPBC):
164190 def setUp (self , random_mock ):
165191 random_mock .rand = ConstGenerator ().rand
166192 random_mock .randn = ConstGenerator ().randn
193+ random_mock .choice = ConstGenerator ().choice
167194 system_1_origin = dpdata .System ("poscars/POSCAR.SiC" , fmt = "vasp/poscar" )
168195 self .system_1 = system_1_origin .perturb (1 , 0.05 , 0.6 , "const" )
169196 self .system_2 = dpdata .System ("poscars/POSCAR.SiC.const" , fmt = "vasp/poscar" )
170197 self .places = 6
171198
172199
200+ class TestPerturbPartAtoms (unittest .TestCase , CompSys , IsPBC ):
201+ @patch ("numpy.random" )
202+ def setUp (self , random_mock ):
203+ random_mock .rand = NormalGenerator ().rand
204+ random_mock .randn = NormalGenerator ().randn
205+ random_mock .choice = NormalGenerator ().choice
206+ system_1_origin = dpdata .System ("poscars/POSCAR.SiC" , fmt = "vasp/poscar" )
207+ self .system_1 = system_1_origin .perturb (1 , 0.05 , 0.6 , "normal" , 0.25 )
208+ self .system_2 = dpdata .System ("poscars/POSCAR.SiC.partpert" , fmt = "vasp/poscar" )
209+ self .places = 6
210+
211+
173212if __name__ == "__main__" :
174213 unittest .main ()
0 commit comments