Skip to content

Commit 390024f

Browse files
committed
qnep small box speedup 2x
1 parent cb49b81 commit 390024f

2 files changed

Lines changed: 77 additions & 62 deletions

File tree

src/force/nep_charge.cu

Lines changed: 69 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1737,11 +1737,6 @@ void NEP_Charge::compute_small_box(
17371737

17381738
const int big_neighbor_size = 2000;
17391739
const int size_x12 = type.size() * big_neighbor_size;
1740-
GPU_Vector<int> NN_radial(type.size());
1741-
GPU_Vector<int> NL_radial(size_x12);
1742-
GPU_Vector<int> NN_angular(type.size());
1743-
GPU_Vector<int> NL_angular(size_x12);
1744-
GPU_Vector<float> r12(size_x12 * 6);
17451740

17461741
find_neighbor_list_small_box<<<grid_size, BLOCK_SIZE>>>(
17471742
paramb,
@@ -1754,24 +1749,24 @@ void NEP_Charge::compute_small_box(
17541749
position_per_atom.data(),
17551750
position_per_atom.data() + N,
17561751
position_per_atom.data() + N * 2,
1757-
NN_radial.data(),
1758-
NL_radial.data(),
1759-
NN_angular.data(),
1760-
NL_angular.data(),
1761-
r12.data(),
1762-
r12.data() + size_x12,
1763-
r12.data() + size_x12 * 2,
1764-
r12.data() + size_x12 * 3,
1765-
r12.data() + size_x12 * 4,
1766-
r12.data() + size_x12 * 5);
1752+
small_box_data.NN_radial.data(),
1753+
small_box_data.NL_radial.data(),
1754+
small_box_data.NN_angular.data(),
1755+
small_box_data.NL_angular.data(),
1756+
small_box_data.r12.data(),
1757+
small_box_data.r12.data() + size_x12,
1758+
small_box_data.r12.data() + size_x12 * 2,
1759+
small_box_data.r12.data() + size_x12 * 3,
1760+
small_box_data.r12.data() + size_x12 * 4,
1761+
small_box_data.r12.data() + size_x12 * 5);
17671762
GPU_CHECK_KERNEL
17681763

17691764
static int num_calls = 0;
17701765
if (num_calls++ % 1000 == 0) {
17711766
std::vector<int> cpu_NN_radial(type.size());
17721767
std::vector<int> cpu_NN_angular(type.size());
1773-
NN_radial.copy_to_host(cpu_NN_radial.data());
1774-
NN_angular.copy_to_host(cpu_NN_angular.data());
1768+
small_box_data.NN_radial.copy_to_host(cpu_NN_radial.data());
1769+
small_box_data.NN_angular.copy_to_host(cpu_NN_angular.data());
17751770
int radial_actual = 0;
17761771
int angular_actual = 0;
17771772
for (int n = 0; n < N; ++n) {
@@ -1796,17 +1791,17 @@ void NEP_Charge::compute_small_box(
17961791
N,
17971792
N1,
17981793
N2,
1799-
(paramb.charge_mode >= 3) ? NN_angular.data() : NN_radial.data(),
1800-
(paramb.charge_mode >= 3) ? NL_angular.data() : NL_radial.data(),
1801-
NN_angular.data(),
1802-
NL_angular.data(),
1794+
(paramb.charge_mode >= 3) ? small_box_data.NN_angular.data() : small_box_data.NN_radial.data(),
1795+
(paramb.charge_mode >= 3) ? small_box_data.NL_angular.data() : small_box_data.NL_radial.data(),
1796+
small_box_data.NN_angular.data(),
1797+
small_box_data.NL_angular.data(),
18031798
type.data(),
1804-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 3 : r12.data(),
1805-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 4 : r12.data() + size_x12,
1806-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 5 : r12.data() + size_x12 * 2,
1807-
r12.data() + size_x12 * 3,
1808-
r12.data() + size_x12 * 4,
1809-
r12.data() + size_x12 * 5,
1799+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 3 : small_box_data.r12.data(),
1800+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 4 : small_box_data.r12.data() + size_x12,
1801+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 5 : small_box_data.r12.data() + size_x12 * 2,
1802+
small_box_data.r12.data() + size_x12 * 3,
1803+
small_box_data.r12.data() + size_x12 * 4,
1804+
small_box_data.r12.data() + size_x12 * 5,
18101805
potential_per_atom.data(),
18111806
nep_data.Fp.data(),
18121807
nep_data.charge.data(),
@@ -1836,12 +1831,12 @@ void NEP_Charge::compute_small_box(
18361831
N,
18371832
N1,
18381833
N2,
1839-
(paramb.charge_mode >= 3) ? NN_angular.data() : NN_radial.data(),
1840-
(paramb.charge_mode >= 3) ? NL_angular.data() : NL_radial.data(),
1834+
(paramb.charge_mode >= 3) ? small_box_data.NN_angular.data() : small_box_data.NN_radial.data(),
1835+
(paramb.charge_mode >= 3) ? small_box_data.NL_angular.data() : small_box_data.NL_radial.data(),
18411836
type.data(),
1842-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 3 : r12.data(),
1843-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 4 : r12.data() + size_x12,
1844-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 5 : r12.data() + size_x12 * 2,
1837+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 3 : small_box_data.r12.data(),
1838+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 4 : small_box_data.r12.data() + size_x12,
1839+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 5 : small_box_data.r12.data() + size_x12 * 2,
18451840
nep_data.charge_derivative.data(),
18461841
nep_data.bec.data());
18471842
GPU_CHECK_KERNEL
@@ -1853,12 +1848,12 @@ void NEP_Charge::compute_small_box(
18531848
N,
18541849
N1,
18551850
N2,
1856-
NN_angular.data(),
1857-
NL_angular.data(),
1851+
small_box_data.NN_angular.data(),
1852+
small_box_data.NL_angular.data(),
18581853
type.data(),
1859-
r12.data() + size_x12 * 3,
1860-
r12.data() + size_x12 * 4,
1861-
r12.data() + size_x12 * 5,
1854+
small_box_data.r12.data() + size_x12 * 3,
1855+
small_box_data.r12.data() + size_x12 * 4,
1856+
small_box_data.r12.data() + size_x12 * 5,
18621857
nep_data.charge_derivative.data(),
18631858
nep_data.sum_fxyz.data(),
18641859
nep_data.bec.data());
@@ -1905,12 +1900,12 @@ void NEP_Charge::compute_small_box(
19051900
N1,
19061901
N2,
19071902
box,
1908-
NN_radial.data(),
1909-
NL_radial.data(),
1903+
small_box_data.NN_radial.data(),
1904+
small_box_data.NL_radial.data(),
19101905
nep_data.charge.data(),
1911-
r12.data(),
1912-
r12.data() + size_x12,
1913-
r12.data() + size_x12 * 2,
1906+
small_box_data.r12.data(),
1907+
small_box_data.r12.data() + size_x12,
1908+
small_box_data.r12.data() + size_x12 * 2,
19141909
force_per_atom.data(),
19151910
force_per_atom.data() + N,
19161911
force_per_atom.data() + N * 2,
@@ -1928,12 +1923,12 @@ void NEP_Charge::compute_small_box(
19281923
N1,
19291924
N2,
19301925
box,
1931-
NN_radial.data(),
1932-
NL_radial.data(),
1926+
small_box_data.NN_radial.data(),
1927+
small_box_data.NL_radial.data(),
19331928
nep_data.C6.data(),
1934-
r12.data(),
1935-
r12.data() + size_x12,
1936-
r12.data() + size_x12 * 2,
1929+
small_box_data.r12.data(),
1930+
small_box_data.r12.data() + size_x12,
1931+
small_box_data.r12.data() + size_x12 * 2,
19371932
force_per_atom.data(),
19381933
force_per_atom.data() + N,
19391934
force_per_atom.data() + N * 2,
@@ -1949,12 +1944,12 @@ void NEP_Charge::compute_small_box(
19491944
N,
19501945
N1,
19511946
N2,
1952-
(paramb.charge_mode >= 3) ? NN_angular.data() : NN_radial.data(),
1953-
(paramb.charge_mode >= 3) ? NL_angular.data() : NL_radial.data(),
1947+
(paramb.charge_mode >= 3) ? small_box_data.NN_angular.data() : small_box_data.NN_radial.data(),
1948+
(paramb.charge_mode >= 3) ? small_box_data.NL_angular.data() : small_box_data.NL_radial.data(),
19541949
type.data(),
1955-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 3 : r12.data(),
1956-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 4 : r12.data() + size_x12,
1957-
(paramb.charge_mode >= 3) ? r12.data() + size_x12 * 5 : r12.data() + size_x12 * 2,
1950+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 3 : small_box_data.r12.data(),
1951+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 4 : small_box_data.r12.data() + size_x12,
1952+
(paramb.charge_mode >= 3) ? small_box_data.r12.data() + size_x12 * 5 : small_box_data.r12.data() + size_x12 * 2,
19581953
nep_data.Fp.data(),
19591954
nep_data.charge_derivative.data(),
19601955
nep_data.D_real.data(),
@@ -1972,12 +1967,12 @@ void NEP_Charge::compute_small_box(
19721967
N,
19731968
N1,
19741969
N2,
1975-
NN_angular.data(),
1976-
NL_angular.data(),
1970+
small_box_data.NN_angular.data(),
1971+
small_box_data.NL_angular.data(),
19771972
type.data(),
1978-
r12.data() + size_x12 * 3,
1979-
r12.data() + size_x12 * 4,
1980-
r12.data() + size_x12 * 5,
1973+
small_box_data.r12.data() + size_x12 * 3,
1974+
small_box_data.r12.data() + size_x12 * 4,
1975+
small_box_data.r12.data() + size_x12 * 5,
19811976
nep_data.Fp.data(),
19821977
nep_data.charge_derivative.data(),
19831978
nep_data.D_real.data(),
@@ -1997,12 +1992,12 @@ void NEP_Charge::compute_small_box(
19971992
zbl,
19981993
N1,
19991994
N2,
2000-
NN_angular.data(),
2001-
NL_angular.data(),
1995+
small_box_data.NN_angular.data(),
1996+
small_box_data.NL_angular.data(),
20021997
type.data(),
2003-
r12.data() + size_x12 * 3,
2004-
r12.data() + size_x12 * 4,
2005-
r12.data() + size_x12 * 5,
1998+
small_box_data.r12.data() + size_x12 * 3,
1999+
small_box_data.r12.data() + size_x12 * 4,
2000+
small_box_data.r12.data() + size_x12 * 5,
20062001
force_per_atom.data(),
20072002
force_per_atom.data() + N,
20082003
force_per_atom.data() + N * 2,
@@ -2087,6 +2082,18 @@ void NEP_Charge::compute(
20872082

20882083
const bool is_small_box = get_expanded_box(paramb.rc_radial, box, ebox);
20892084
if (is_small_box) {
2085+
// update small_box_data
2086+
const int current_num_atoms = type.size();
2087+
if (small_box_data.NN_radial.size() != current_num_atoms) {
2088+
const int big_neighbor_size = 2000;
2089+
const int size_x12 = current_num_atoms * big_neighbor_size;
2090+
2091+
small_box_data.NN_radial.resize(current_num_atoms);
2092+
small_box_data.NL_radial.resize(size_x12);
2093+
small_box_data.NN_angular.resize(current_num_atoms);
2094+
small_box_data.NL_angular.resize(size_x12);
2095+
small_box_data.r12.resize(size_x12 * 6);
2096+
}
20902097
compute_small_box(
20912098
box, type, position_per_atom, potential_per_atom, force_per_atom, virial_per_atom);
20922099
} else {

src/force/nep_charge.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ public:
110110
float h[18];
111111
};
112112

113+
struct Small_Box_Data {
114+
GPU_Vector<int> NN_radial;
115+
GPU_Vector<int> NL_radial;
116+
GPU_Vector<int> NN_angular;
117+
GPU_Vector<int> NL_angular;
118+
GPU_Vector<float> r12;
119+
} small_box_data;
120+
113121
struct Charge_Para {
114122
int num_kpoints_max = 1;
115123
float alpha = 0.5f; // 1 / (2 Angstrom)

0 commit comments

Comments
 (0)