@@ -1737,11 +1737,6 @@ void NEP_Charge::compute_small_box(
17371737
17381738 const int big_neighbor_size = 2000 ;
17391739 const int size_x12 = type.size () * big_neighbor_size;
1740- GPU_Vector<int > NN_radial (type.size ());
1741- GPU_Vector<int > NL_radial (size_x12);
1742- GPU_Vector<int > NN_angular (type.size ());
1743- GPU_Vector<int > NL_angular (size_x12);
1744- GPU_Vector<float > r12 (size_x12 * 6 );
17451740
17461741 find_neighbor_list_small_box<<<grid_size, BLOCK_SIZE>>> (
17471742 paramb,
@@ -1754,24 +1749,24 @@ void NEP_Charge::compute_small_box(
17541749 position_per_atom.data (),
17551750 position_per_atom.data () + N,
17561751 position_per_atom.data () + N * 2 ,
1757- NN_radial.data (),
1758- NL_radial.data (),
1759- NN_angular.data (),
1760- NL_angular.data (),
1761- r12.data (),
1762- r12.data () + size_x12,
1763- r12.data () + size_x12 * 2 ,
1764- r12.data () + size_x12 * 3 ,
1765- r12.data () + size_x12 * 4 ,
1766- r12.data () + size_x12 * 5 );
1752+ small_box_data. NN_radial .data (),
1753+ small_box_data. NL_radial .data (),
1754+ small_box_data. NN_angular .data (),
1755+ small_box_data. NL_angular .data (),
1756+ small_box_data. r12 .data (),
1757+ small_box_data. r12 .data () + size_x12,
1758+ small_box_data. r12 .data () + size_x12 * 2 ,
1759+ small_box_data. r12 .data () + size_x12 * 3 ,
1760+ small_box_data. r12 .data () + size_x12 * 4 ,
1761+ small_box_data. r12 .data () + size_x12 * 5 );
17671762 GPU_CHECK_KERNEL
17681763
17691764 static int num_calls = 0 ;
17701765 if (num_calls++ % 1000 == 0 ) {
17711766 std::vector<int > cpu_NN_radial (type.size ());
17721767 std::vector<int > cpu_NN_angular (type.size ());
1773- NN_radial.copy_to_host (cpu_NN_radial.data ());
1774- NN_angular.copy_to_host (cpu_NN_angular.data ());
1768+ small_box_data. NN_radial .copy_to_host (cpu_NN_radial.data ());
1769+ small_box_data. NN_angular .copy_to_host (cpu_NN_angular.data ());
17751770 int radial_actual = 0 ;
17761771 int angular_actual = 0 ;
17771772 for (int n = 0 ; n < N; ++n) {
@@ -1796,17 +1791,17 @@ void NEP_Charge::compute_small_box(
17961791 N,
17971792 N1,
17981793 N2,
1799- (paramb.charge_mode >= 3 ) ? NN_angular.data () : NN_radial.data (),
1800- (paramb.charge_mode >= 3 ) ? NL_angular.data () : NL_radial.data (),
1801- NN_angular.data (),
1802- NL_angular.data (),
1794+ (paramb.charge_mode >= 3 ) ? small_box_data. NN_angular .data () : small_box_data. NN_radial .data (),
1795+ (paramb.charge_mode >= 3 ) ? small_box_data. NL_angular .data () : small_box_data. NL_radial .data (),
1796+ small_box_data. NN_angular .data (),
1797+ small_box_data. NL_angular .data (),
18031798 type.data (),
1804- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 3 : r12.data (),
1805- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 4 : r12.data () + size_x12,
1806- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 5 : r12.data () + size_x12 * 2 ,
1807- r12.data () + size_x12 * 3 ,
1808- r12.data () + size_x12 * 4 ,
1809- r12.data () + size_x12 * 5 ,
1799+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 3 : small_box_data. r12 .data (),
1800+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 4 : small_box_data. r12 .data () + size_x12,
1801+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 5 : small_box_data. r12 .data () + size_x12 * 2 ,
1802+ small_box_data. r12 .data () + size_x12 * 3 ,
1803+ small_box_data. r12 .data () + size_x12 * 4 ,
1804+ small_box_data. r12 .data () + size_x12 * 5 ,
18101805 potential_per_atom.data (),
18111806 nep_data.Fp .data (),
18121807 nep_data.charge .data (),
@@ -1836,12 +1831,12 @@ void NEP_Charge::compute_small_box(
18361831 N,
18371832 N1,
18381833 N2,
1839- (paramb.charge_mode >= 3 ) ? NN_angular.data () : NN_radial.data (),
1840- (paramb.charge_mode >= 3 ) ? NL_angular.data () : NL_radial.data (),
1834+ (paramb.charge_mode >= 3 ) ? small_box_data. NN_angular .data () : small_box_data. NN_radial .data (),
1835+ (paramb.charge_mode >= 3 ) ? small_box_data. NL_angular .data () : small_box_data. NL_radial .data (),
18411836 type.data (),
1842- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 3 : r12.data (),
1843- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 4 : r12.data () + size_x12,
1844- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 5 : r12.data () + size_x12 * 2 ,
1837+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 3 : small_box_data. r12 .data (),
1838+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 4 : small_box_data. r12 .data () + size_x12,
1839+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 5 : small_box_data. r12 .data () + size_x12 * 2 ,
18451840 nep_data.charge_derivative .data (),
18461841 nep_data.bec .data ());
18471842 GPU_CHECK_KERNEL
@@ -1853,12 +1848,12 @@ void NEP_Charge::compute_small_box(
18531848 N,
18541849 N1,
18551850 N2,
1856- NN_angular.data (),
1857- NL_angular.data (),
1851+ small_box_data. NN_angular .data (),
1852+ small_box_data. NL_angular .data (),
18581853 type.data (),
1859- r12.data () + size_x12 * 3 ,
1860- r12.data () + size_x12 * 4 ,
1861- r12.data () + size_x12 * 5 ,
1854+ small_box_data. r12 .data () + size_x12 * 3 ,
1855+ small_box_data. r12 .data () + size_x12 * 4 ,
1856+ small_box_data. r12 .data () + size_x12 * 5 ,
18621857 nep_data.charge_derivative .data (),
18631858 nep_data.sum_fxyz .data (),
18641859 nep_data.bec .data ());
@@ -1905,12 +1900,12 @@ void NEP_Charge::compute_small_box(
19051900 N1,
19061901 N2,
19071902 box,
1908- NN_radial.data (),
1909- NL_radial.data (),
1903+ small_box_data. NN_radial .data (),
1904+ small_box_data. NL_radial .data (),
19101905 nep_data.charge .data (),
1911- r12.data (),
1912- r12.data () + size_x12,
1913- r12.data () + size_x12 * 2 ,
1906+ small_box_data. r12 .data (),
1907+ small_box_data. r12 .data () + size_x12,
1908+ small_box_data. r12 .data () + size_x12 * 2 ,
19141909 force_per_atom.data (),
19151910 force_per_atom.data () + N,
19161911 force_per_atom.data () + N * 2 ,
@@ -1928,12 +1923,12 @@ void NEP_Charge::compute_small_box(
19281923 N1,
19291924 N2,
19301925 box,
1931- NN_radial.data (),
1932- NL_radial.data (),
1926+ small_box_data. NN_radial .data (),
1927+ small_box_data. NL_radial .data (),
19331928 nep_data.C6 .data (),
1934- r12.data (),
1935- r12.data () + size_x12,
1936- r12.data () + size_x12 * 2 ,
1929+ small_box_data. r12 .data (),
1930+ small_box_data. r12 .data () + size_x12,
1931+ small_box_data. r12 .data () + size_x12 * 2 ,
19371932 force_per_atom.data (),
19381933 force_per_atom.data () + N,
19391934 force_per_atom.data () + N * 2 ,
@@ -1949,12 +1944,12 @@ void NEP_Charge::compute_small_box(
19491944 N,
19501945 N1,
19511946 N2,
1952- (paramb.charge_mode >= 3 ) ? NN_angular.data () : NN_radial.data (),
1953- (paramb.charge_mode >= 3 ) ? NL_angular.data () : NL_radial.data (),
1947+ (paramb.charge_mode >= 3 ) ? small_box_data. NN_angular .data () : small_box_data. NN_radial .data (),
1948+ (paramb.charge_mode >= 3 ) ? small_box_data. NL_angular .data () : small_box_data. NL_radial .data (),
19541949 type.data (),
1955- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 3 : r12.data (),
1956- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 4 : r12.data () + size_x12,
1957- (paramb.charge_mode >= 3 ) ? r12.data () + size_x12 * 5 : r12.data () + size_x12 * 2 ,
1950+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 3 : small_box_data. r12 .data (),
1951+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 4 : small_box_data. r12 .data () + size_x12,
1952+ (paramb.charge_mode >= 3 ) ? small_box_data. r12 .data () + size_x12 * 5 : small_box_data. r12 .data () + size_x12 * 2 ,
19581953 nep_data.Fp .data (),
19591954 nep_data.charge_derivative .data (),
19601955 nep_data.D_real .data (),
@@ -1972,12 +1967,12 @@ void NEP_Charge::compute_small_box(
19721967 N,
19731968 N1,
19741969 N2,
1975- NN_angular.data (),
1976- NL_angular.data (),
1970+ small_box_data. NN_angular .data (),
1971+ small_box_data. NL_angular .data (),
19771972 type.data (),
1978- r12.data () + size_x12 * 3 ,
1979- r12.data () + size_x12 * 4 ,
1980- r12.data () + size_x12 * 5 ,
1973+ small_box_data. r12 .data () + size_x12 * 3 ,
1974+ small_box_data. r12 .data () + size_x12 * 4 ,
1975+ small_box_data. r12 .data () + size_x12 * 5 ,
19811976 nep_data.Fp .data (),
19821977 nep_data.charge_derivative .data (),
19831978 nep_data.D_real .data (),
@@ -1997,12 +1992,12 @@ void NEP_Charge::compute_small_box(
19971992 zbl,
19981993 N1,
19991994 N2,
2000- NN_angular.data (),
2001- NL_angular.data (),
1995+ small_box_data. NN_angular .data (),
1996+ small_box_data. NL_angular .data (),
20021997 type.data (),
2003- r12.data () + size_x12 * 3 ,
2004- r12.data () + size_x12 * 4 ,
2005- r12.data () + size_x12 * 5 ,
1998+ small_box_data. r12 .data () + size_x12 * 3 ,
1999+ small_box_data. r12 .data () + size_x12 * 4 ,
2000+ small_box_data. r12 .data () + size_x12 * 5 ,
20062001 force_per_atom.data (),
20072002 force_per_atom.data () + N,
20082003 force_per_atom.data () + N * 2 ,
@@ -2087,6 +2082,18 @@ void NEP_Charge::compute(
20872082
20882083 const bool is_small_box = get_expanded_box (paramb.rc_radial , box, ebox);
20892084 if (is_small_box) {
2085+ // update small_box_data
2086+ const int current_num_atoms = type.size ();
2087+ if (small_box_data.NN_radial .size () != current_num_atoms) {
2088+ const int big_neighbor_size = 2000 ;
2089+ const int size_x12 = current_num_atoms * big_neighbor_size;
2090+
2091+ small_box_data.NN_radial .resize (current_num_atoms);
2092+ small_box_data.NL_radial .resize (size_x12);
2093+ small_box_data.NN_angular .resize (current_num_atoms);
2094+ small_box_data.NL_angular .resize (size_x12);
2095+ small_box_data.r12 .resize (size_x12 * 6 );
2096+ }
20902097 compute_small_box (
20912098 box, type, position_per_atom, potential_per_atom, force_per_atom, virial_per_atom);
20922099 } else {
0 commit comments