Skip to content

Commit cb49b81

Browse files
committed
qnep box
1 parent fcd2fa5 commit cb49b81

2 files changed

Lines changed: 67 additions & 70 deletions

File tree

src/force/nep_charge.cu

Lines changed: 47 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -467,11 +467,10 @@ static __global__ void find_neighbor_list_large_box(
467467
continue;
468468
}
469469

470-
double x12double = g_x[n2] - x1;
471-
double y12double = g_y[n2] - y1;
472-
double z12double = g_z[n2] - z1;
473-
apply_mic(box, x12double, y12double, z12double);
474-
float x12 = float(x12double), y12 = float(y12double), z12 = float(z12double);
470+
float x12 = g_x[n2] - x1;
471+
float y12 = g_y[n2] - y1;
472+
float z12 = g_z[n2] - z1;
473+
apply_mic(box, x12, y12, z12);
475474
float d12_square = x12 * x12 + y12 * y12 + z12 * z12;
476475

477476
float rc_radial = paramb.rc_radial;
@@ -530,11 +529,10 @@ static __global__ void find_descriptor(
530529
// get radial descriptors
531530
for (int i1 = 0; i1 < g_NN[n1]; ++i1) {
532531
int n2 = g_NL[n1 + N * i1];
533-
double x12double = g_x[n2] - x1;
534-
double y12double = g_y[n2] - y1;
535-
double z12double = g_z[n2] - z1;
536-
apply_mic(box, x12double, y12double, z12double);
537-
float x12 = float(x12double), y12 = float(y12double), z12 = float(z12double);
532+
float x12 = g_x[n2] - x1;
533+
float y12 = g_y[n2] - y1;
534+
float z12 = g_z[n2] - z1;
535+
apply_mic(box, x12, y12, z12);
538536
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
539537
float fc12;
540538
int t2 = g_type[n2];
@@ -560,11 +558,10 @@ static __global__ void find_descriptor(
560558
float s[NUM_OF_ABC] = {0.0f};
561559
for (int i1 = 0; i1 < g_NN_angular[n1]; ++i1) {
562560
int n2 = g_NL_angular[n1 + N * i1];
563-
double x12double = g_x[n2] - x1;
564-
double y12double = g_y[n2] - y1;
565-
double z12double = g_z[n2] - z1;
566-
apply_mic(box, x12double, y12double, z12double);
567-
float x12 = float(x12double), y12 = float(y12double), z12 = float(z12double);
561+
float x12 = g_x[n2] - x1;
562+
float y12 = g_y[n2] - y1;
563+
float z12 = g_z[n2] - z1;
564+
apply_mic(box, x12, y12, z12);
568565
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
569566
float fc12;
570567
int t2 = g_type[n2];
@@ -724,11 +721,11 @@ static __global__ void find_bec_radial(
724721
for (int i1 = 0; i1 < g_NN[n1]; ++i1) {
725722
int n2 = g_NL[n1 + N * i1];
726723
int t2 = g_type[n2];
727-
double x12double = g_x[n2] - x1;
728-
double y12double = g_y[n2] - y1;
729-
double z12double = g_z[n2] - z1;
730-
apply_mic(box, x12double, y12double, z12double);
731-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
724+
float x12 = g_x[n2] - x1;
725+
float y12 = g_y[n2] - y1;
726+
float z12 = g_z[n2] - z1;
727+
apply_mic(box, x12, y12, z12);
728+
float r12[3] = {x12, y12, z12};
732729
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
733730
float d12inv = 1.0f / d12;
734731
float fc12, fcp12;
@@ -823,11 +820,11 @@ static __global__ void find_bec_angular(
823820
double z1 = g_z[n1];
824821
for (int i1 = 0; i1 < g_NN_angular[n1]; ++i1) {
825822
int n2 = g_NL_angular[n1 + N * i1];
826-
double x12double = g_x[n2] - x1;
827-
double y12double = g_y[n2] - y1;
828-
double z12double = g_z[n2] - z1;
829-
apply_mic(box, x12double, y12double, z12double);
830-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
823+
float x12 = g_x[n2] - x1;
824+
float y12 = g_y[n2] - y1;
825+
float z12 = g_z[n2] - z1;
826+
apply_mic(box, x12, y12, z12);
827+
float r12[3] = {x12, y12, z12};
831828
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
832829
float f12[3] = {0.0f};
833830
float fc12, fcp12;
@@ -949,11 +946,11 @@ static __global__ void find_force_radial(
949946
for (int i1 = 0; i1 < g_NN[n1]; ++i1) {
950947
int n2 = g_NL[n1 + N * i1];
951948
int t2 = g_type[n2];
952-
double x12double = g_x[n2] - x1;
953-
double y12double = g_y[n2] - y1;
954-
double z12double = g_z[n2] - z1;
955-
apply_mic(box, x12double, y12double, z12double);
956-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
949+
float x12 = g_x[n2] - x1;
950+
float y12 = g_y[n2] - y1;
951+
float z12 = g_z[n2] - z1;
952+
apply_mic(box, x12, y12, z12);
953+
float r12[3] = {x12, y12, z12};
957954
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
958955
float d12inv = 1.0f / d12;
959956
float f12[3] = {0.0f};
@@ -1068,11 +1065,11 @@ static __global__ void find_partial_force_angular(
10681065
for (int i1 = 0; i1 < g_NN_angular[n1]; ++i1) {
10691066
int index = i1 * N + n1;
10701067
int n2 = g_NL_angular[n1 + N * i1];
1071-
double x12double = g_x[n2] - x1;
1072-
double y12double = g_y[n2] - y1;
1073-
double z12double = g_z[n2] - z1;
1074-
apply_mic(box, x12double, y12double, z12double);
1075-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
1068+
float x12 = g_x[n2] - x1;
1069+
float y12 = g_y[n2] - y1;
1070+
float z12 = g_z[n2] - z1;
1071+
apply_mic(box, x12, y12, z12);
1072+
float r12[3] = {x12, y12, z12};
10761073
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
10771074
float f12[3] = {0.0f};
10781075
float fc12, fcp12;
@@ -1155,11 +1152,11 @@ static __global__ void find_force_ZBL(
11551152
float pow_zi = pow(float(zi), 0.23f);
11561153
for (int i1 = 0; i1 < g_NN[n1]; ++i1) {
11571154
int n2 = g_NL[n1 + N * i1];
1158-
double x12double = g_x[n2] - x1;
1159-
double y12double = g_y[n2] - y1;
1160-
double z12double = g_z[n2] - z1;
1161-
apply_mic(box, x12double, y12double, z12double);
1162-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
1155+
float x12 = g_x[n2] - x1;
1156+
float y12 = g_y[n2] - y1;
1157+
float z12 = g_z[n2] - z1;
1158+
apply_mic(box, x12, y12, z12);
1159+
float r12[3] = {x12, y12, z12};
11631160
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
11641161
float d12inv = 1.0f / d12;
11651162
float f, fp;
@@ -1271,11 +1268,11 @@ static __global__ void find_force_charge_real_space(
12711268
int n2 = g_NL[n1 + N * i1];
12721269
float q2 = g_charge[n2];
12731270
float qq = q1 * q2;
1274-
double x12double = g_x[n2] - x1;
1275-
double y12double = g_y[n2] - y1;
1276-
double z12double = g_z[n2] - z1;
1277-
apply_mic(box, x12double, y12double, z12double);
1278-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
1271+
float x12 = g_x[n2] - x1;
1272+
float y12 = g_y[n2] - y1;
1273+
float z12 = g_z[n2] - z1;
1274+
apply_mic(box, x12, y12, z12);
1275+
float r12[3] = {x12, y12, z12};
12791276
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
12801277
float d12inv = 1.0f / d12;
12811278

@@ -1363,11 +1360,11 @@ static __global__ void find_force_vdw_static(
13631360
int n2 = g_NL[n1 + N * i1];
13641361
float q2 = g_charge[n2];
13651362
float qq = q1 * q1 * q2 * q2;
1366-
double x12double = g_x[n2] - x1;
1367-
double y12double = g_y[n2] - y1;
1368-
double z12double = g_z[n2] - z1;
1369-
apply_mic(box, x12double, y12double, z12double);
1370-
float r12[3] = {float(x12double), float(y12double), float(z12double)};
1363+
float x12 = g_x[n2] - x1;
1364+
float y12 = g_y[n2] - y1;
1365+
float z12 = g_z[n2] - z1;
1366+
apply_mic(box, x12, y12, z12);
1367+
float r12[3] = {x12, y12, z12};
13711368
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
13721369
float d12_2 = d12 * d12;
13731370
float d12_4 = d12_2 * d12_2;

src/force/nep_charge_small_box.cuh

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,11 @@ static __device__ __inline__ double atomicAdd(double* address, double val)
3737
#endif
3838

3939
static __device__ void apply_mic_small_box(
40-
const Box& box, const NEP_Charge::ExpandedBox& ebox, double& x12, double& y12, double& z12)
40+
const Box& box, const NEP_Charge::ExpandedBox& ebox, float& x12, float& y12, float& z12)
4141
{
42-
double sx12 = ebox.h[9] * x12 + ebox.h[10] * y12 + ebox.h[11] * z12;
43-
double sy12 = ebox.h[12] * x12 + ebox.h[13] * y12 + ebox.h[14] * z12;
44-
double sz12 = ebox.h[15] * x12 + ebox.h[16] * y12 + ebox.h[17] * z12;
42+
float sx12 = ebox.h[9] * x12 + ebox.h[10] * y12 + ebox.h[11] * z12;
43+
float sy12 = ebox.h[12] * x12 + ebox.h[13] * y12 + ebox.h[14] * z12;
44+
float sz12 = ebox.h[15] * x12 + ebox.h[16] * y12 + ebox.h[17] * z12;
4545
if (box.pbc_x == 1)
4646
sx12 -= nearbyint(sx12);
4747
if (box.pbc_y == 1)
@@ -77,9 +77,9 @@ static __global__ void find_neighbor_list_small_box(
7777
{
7878
int n1 = blockIdx.x * blockDim.x + threadIdx.x + N1;
7979
if (n1 < N2) {
80-
double x1 = g_x[n1];
81-
double y1 = g_y[n1];
82-
double z1 = g_z[n1];
80+
float x1 = g_x[n1];
81+
float y1 = g_y[n1];
82+
float z1 = g_z[n1];
8383
int count_radial = 0;
8484
int count_angular = 0;
8585
for (int n2 = N1; n2 < N2; ++n2) {
@@ -90,14 +90,14 @@ static __global__ void find_neighbor_list_small_box(
9090
continue; // exclude self
9191
}
9292

93-
double delta[3];
94-
delta[0] = box.cpu_h[0] * ia + box.cpu_h[1] * ib + box.cpu_h[2] * ic;
95-
delta[1] = box.cpu_h[3] * ia + box.cpu_h[4] * ib + box.cpu_h[5] * ic;
96-
delta[2] = box.cpu_h[6] * ia + box.cpu_h[7] * ib + box.cpu_h[8] * ic;
93+
float delta[3];
94+
delta[0] = box.float_h[0] * ia + box.float_h[1] * ib + box.float_h[2] * ic;
95+
delta[1] = box.float_h[3] * ia + box.float_h[4] * ib + box.float_h[5] * ic;
96+
delta[2] = box.float_h[6] * ia + box.float_h[7] * ib + box.float_h[8] * ic;
9797

98-
double x12 = g_x[n2] + delta[0] - x1;
99-
double y12 = g_y[n2] + delta[1] - y1;
100-
double z12 = g_z[n2] + delta[2] - z1;
98+
float x12 = g_x[n2] + delta[0] - x1;
99+
float y12 = g_y[n2] + delta[1] - y1;
100+
float z12 = g_z[n2] + delta[2] - z1;
101101

102102
apply_mic_small_box(box, ebox, x12, y12, z12);
103103

@@ -108,16 +108,16 @@ static __global__ void find_neighbor_list_small_box(
108108

109109
if (distance_square < rc_radial * rc_radial) {
110110
g_NL_radial[count_radial * N + n1] = n2;
111-
g_x12_radial[count_radial * N + n1] = float(x12);
112-
g_y12_radial[count_radial * N + n1] = float(y12);
113-
g_z12_radial[count_radial * N + n1] = float(z12);
111+
g_x12_radial[count_radial * N + n1] = x12;
112+
g_y12_radial[count_radial * N + n1] = y12;
113+
g_z12_radial[count_radial * N + n1] = z12;
114114
count_radial++;
115115
}
116116
if (distance_square < rc_angular * rc_angular) {
117117
g_NL_angular[count_angular * N + n1] = n2;
118-
g_x12_angular[count_angular * N + n1] = float(x12);
119-
g_y12_angular[count_angular * N + n1] = float(y12);
120-
g_z12_angular[count_angular * N + n1] = float(z12);
118+
g_x12_angular[count_angular * N + n1] = x12;
119+
g_y12_angular[count_angular * N + n1] = y12;
120+
g_z12_angular[count_angular * N + n1] = z12;
121121
count_angular++;
122122
}
123123
}

0 commit comments

Comments
 (0)