Skip to content

Commit 284cdf1

Browse files
authored
Merge pull request brucefan1983#1348 from brucefan1983/eam_speedup
Eam speedup
2 parents d604a01 + 9a09b6b commit 284cdf1

7 files changed

Lines changed: 381 additions & 379 deletions

File tree

src/force/eam.cu

Lines changed: 104 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,17 @@ void EAM::initialize_eam2004zhou(FILE* fid, int num_types)
5959

6060
potential_model = 0;
6161

62-
rc = 0.0;
62+
rc = 0.0f;
6363
for (int type = 0; type < num_types; ++type) {
64-
double x[21];
64+
float x[21];
6565
for (int n = 0; n < 21; n++) {
66-
int count = fscanf(fid, "%lf", &x[n]);
66+
int count = fscanf(fid, "%f", &x[n]);
6767
PRINT_SCANF_ERROR(count, 1, "Reading error for EAM potential.");
6868
}
69-
eam2004zhou.re_inv[type] = 1.0 / x[0];
69+
eam2004zhou.re_inv[type] = 1.0f / x[0];
7070
eam2004zhou.fe[type] = x[1];
71-
eam2004zhou.rho_e_inv[type] = 1.0 / x[2];
72-
eam2004zhou.rho_s_inv[type] = 1.0 / x[3];
71+
eam2004zhou.rho_e_inv[type] = 1.0f / x[2];
72+
eam2004zhou.rho_s_inv[type] = 1.0f / x[3];
7373
eam2004zhou.alpha[type] = x[4];
7474
eam2004zhou.beta[type] = x[5];
7575
eam2004zhou.A[type] = x[6];
@@ -89,7 +89,7 @@ void EAM::initialize_eam2004zhou(FILE* fid, int num_types)
8989
eam2004zhou.rc[type] = x[20];
9090
eam2004zhou.rho_n[type] = x[2] * 0.85;
9191
eam2004zhou.rho_0[type] = x[2] * 1.15;
92-
eam2004zhou.rho_n_inv[type] = 1.0 / eam2004zhou.rho_n[type];
92+
eam2004zhou.rho_n_inv[type] = 1.0f / eam2004zhou.rho_n[type];
9393
if (rc < eam2004zhou.rc[type]) {
9494
rc = eam2004zhou.rc[type];
9595
}
@@ -106,9 +106,9 @@ void EAM::initialize_eam2006dai(FILE* fid)
106106

107107
potential_model = 1;
108108

109-
double x[9];
109+
float x[9];
110110
for (int n = 0; n < 9; n++) {
111-
int count = fscanf(fid, "%lf", &x[n]);
111+
int count = fscanf(fid, "%f", &x[n]);
112112
PRINT_SCANF_ERROR(count, 1, "Reading error for EAM potential.");
113113
}
114114
eam2006dai.A = x[0];
@@ -131,153 +131,153 @@ EAM::~EAM(void)
131131

132132
// pair function (phi and phip have been intentionally halved here)
133133
static __device__ void
134-
find_phi(const EAM2004Zhou& eam, const int type, const double d12, double& phi, double& phip)
134+
find_phi(const EAM2004Zhou& eam, const int type, const float d12, float& phi, float& phip)
135135
{
136-
double r_ratio = d12 * eam.re_inv[type];
137-
double tmp1 = (r_ratio - eam.kappa[type]) * (r_ratio - eam.kappa[type]); // 2
136+
float r_ratio = d12 * eam.re_inv[type];
137+
float tmp1 = (r_ratio - eam.kappa[type]) * (r_ratio - eam.kappa[type]); // 2
138138
tmp1 *= tmp1; // 4
139139
tmp1 *= tmp1 * tmp1 * tmp1 * tmp1; // 20
140-
double tmp2 = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
140+
float tmp2 = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
141141
tmp2 *= tmp2; // 4
142142
tmp2 *= tmp2 * tmp2 * tmp2 * tmp2; // 20
143-
double phi1 = 0.5 * eam.A[type] * exp(-eam.alpha[type] * (r_ratio - 1.0)) / (1.0 + tmp1);
144-
double phi2 = 0.5 * eam.B[type] * exp(-eam.beta[type] * (r_ratio - 1.0)) / (1.0 + tmp2);
143+
float phi1 = 0.5f * eam.A[type] * exp(-eam.alpha[type] * (r_ratio - 1.0f)) / (1.0f + tmp1);
144+
float phi2 = 0.5f * eam.B[type] * exp(-eam.beta[type] * (r_ratio - 1.0f)) / (1.0f + tmp2);
145145
phi = phi1 - phi2;
146146
phip = (phi2 * eam.re_inv[type]) *
147-
(eam.beta[type] + 20.0 * tmp2 / (r_ratio - eam.lambda[type]) / (1.0 + tmp2)) -
147+
(eam.beta[type] + 20.0f * tmp2 / (r_ratio - eam.lambda[type]) / (1.0f + tmp2)) -
148148
(phi1 * eam.re_inv[type]) *
149-
(eam.alpha[type] + 20.0 * tmp1 / (r_ratio - eam.kappa[type]) / (1.0 + tmp1));
149+
(eam.alpha[type] + 20.0f * tmp1 / (r_ratio - eam.kappa[type]) / (1.0f + tmp1));
150150
}
151151

152152
// density function f(r)
153-
static __device__ void find_f(const EAM2004Zhou& eam, const int type, const double d12, double& f)
153+
static __device__ void find_f(const EAM2004Zhou& eam, const int type, const float d12, float& f)
154154
{
155-
double r_ratio = d12 * eam.re_inv[type];
156-
double tmp = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
155+
float r_ratio = d12 * eam.re_inv[type];
156+
float tmp = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
157157
tmp *= tmp; // 4
158158
tmp *= tmp * tmp * tmp * tmp; // 20
159-
f = eam.fe[type] * exp(-eam.beta[type] * (r_ratio - 1.0)) / (1.0 + tmp);
159+
f = eam.fe[type] * exp(-eam.beta[type] * (r_ratio - 1.0f)) / (1.0f + tmp);
160160
}
161161

162162
// derivative of the density function f'(r)
163-
static __device__ void find_fp(const EAM2004Zhou& eam, const int type, const double d12, double& fp)
163+
static __device__ void find_fp(const EAM2004Zhou& eam, const int type, const float d12, float& fp)
164164
{
165-
double r_ratio = d12 * eam.re_inv[type];
166-
double tmp = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
165+
float r_ratio = d12 * eam.re_inv[type];
166+
float tmp = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
167167
tmp *= tmp; // 4
168168
tmp *= tmp * tmp * tmp * tmp; // 20
169-
double f = eam.fe[type] * exp(-eam.beta[type] * (r_ratio - 1.0)) / (1.0 + tmp);
169+
float f = eam.fe[type] * exp(-eam.beta[type] * (r_ratio - 1.0f)) / (1.0f + tmp);
170170
fp = -(f * eam.re_inv[type]) *
171-
(eam.beta[type] + 20.0 * tmp / (r_ratio - eam.lambda[type]) / (1.0 + tmp));
171+
(eam.beta[type] + 20.0f * tmp / (r_ratio - eam.lambda[type]) / (1.0f + tmp));
172172
}
173173

174174
static __device__ void
175-
find_f_and_fp(const EAM2004Zhou& eam, const int type, const double d12, double& f, double& fp)
175+
find_f_and_fp(const EAM2004Zhou& eam, const int type, const float d12, float& f, float& fp)
176176
{
177-
double r_ratio = d12 * eam.re_inv[type];
178-
double tmp = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
177+
float r_ratio = d12 * eam.re_inv[type];
178+
float tmp = (r_ratio - eam.lambda[type]) * (r_ratio - eam.lambda[type]); // 2
179179
tmp *= tmp; // 4
180180
tmp *= tmp * tmp * tmp * tmp; // 20
181-
f = eam.fe[type] * exp(-eam.beta[type] * (r_ratio - 1.0)) / (1.0 + tmp);
181+
f = eam.fe[type] * exp(-eam.beta[type] * (r_ratio - 1.0f)) / (1.0f + tmp);
182182
fp = -(f * eam.re_inv[type]) *
183-
(eam.beta[type] + 20.0 * tmp / (r_ratio - eam.lambda[type]) / (1.0 + tmp));
183+
(eam.beta[type] + 20.0f * tmp / (r_ratio - eam.lambda[type]) / (1.0f + tmp));
184184
}
185185

186186
// pair function for EAM2004Zhou
187187
static __device__ void find_phi(
188188
const EAM2004Zhou& eam,
189189
const int type1,
190190
const int type2,
191-
const double d12,
192-
double& phi,
193-
double& phip)
191+
const float d12,
192+
float& phi,
193+
float& phip)
194194
{
195195
if (type1 == type2) {
196196
find_phi(eam, type1, d12, phi, phip);
197197
} else {
198-
double phi1, phip1;
198+
float phi1, phip1;
199199
find_phi(eam, type1, d12, phi1, phip1);
200-
double phi2, phip2;
200+
float phi2, phip2;
201201
find_phi(eam, type2, d12, phi2, phip2);
202-
double f1, fp1;
202+
float f1, fp1;
203203
find_f_and_fp(eam, type1, d12, f1, fp1);
204-
double f2, fp2;
204+
float f2, fp2;
205205
find_f_and_fp(eam, type2, d12, f2, fp2);
206-
double f1inv = 1.0 / f1;
207-
double f2inv = 1.0 / f2;
208-
phi = 0.5 * (phi1 * f2 * f1inv + phi2 * f1 * f2inv);
206+
float f1inv = 1.0f / f1;
207+
float f2inv = 1.0f / f2;
208+
phi = 0.5f * (phi1 * f2 * f1inv + phi2 * f1 * f2inv);
209209
phip = (phip1 * f2 + phi1 * (fp2 - f2 * fp1 * f1inv)) * f1inv;
210210
phip += (phip2 * f1 + phi2 * (fp1 - f1 * fp2 * f2inv)) * f2inv;
211-
phip *= 0.5;
211+
phip *= 0.5f;
212212
}
213213
}
214214

215215
// embedding function
216216
static __device__ void
217-
find_F(const EAM2004Zhou& eam, const int type, const double rho, double& F, double& Fp)
217+
find_F(const EAM2004Zhou& eam, const int type, const float rho, float& F, float& Fp)
218218
{
219219
if (rho < eam.rho_n[type]) {
220-
double x = rho * eam.rho_n_inv[type] - 1.0;
220+
float x = rho * eam.rho_n_inv[type] - 1.0f;
221221
F = ((eam.Fn3[type] * x + eam.Fn2[type]) * x + eam.Fn1[type]) * x + eam.Fn0[type];
222-
Fp = ((3.0 * eam.Fn3[type] * x + 2.0 * eam.Fn2[type]) * x + eam.Fn1[type]) / eam.rho_n[type];
222+
Fp = ((3.0f * eam.Fn3[type] * x + 2.0f * eam.Fn2[type]) * x + eam.Fn1[type]) / eam.rho_n[type];
223223
} else if (rho < eam.rho_0[type]) {
224-
double x = rho * eam.rho_e_inv[type] - 1.0;
224+
float x = rho * eam.rho_e_inv[type] - 1.0f;
225225
F = ((eam.F3[type] * x + eam.F2[type]) * x + eam.F1[type]) * x + eam.F0[type];
226-
Fp = ((3.0 * eam.F3[type] * x + 2.0 * eam.F2[type]) * x + eam.F1[type]) * eam.rho_e_inv[type];
226+
Fp = ((3.0f * eam.F3[type] * x + 2.0f * eam.F2[type]) * x + eam.F1[type]) * eam.rho_e_inv[type];
227227
} else {
228-
double x = rho * eam.rho_s_inv[type];
229-
double x_eta = pow(x, eam.eta[type]);
230-
F = eam.Fe[type] * (1.0 - eam.eta[type] * log(x)) * x_eta;
228+
float x = rho * eam.rho_s_inv[type];
229+
float x_eta = pow(x, eam.eta[type]);
230+
F = eam.Fe[type] * (1.0f - eam.eta[type] * log(x)) * x_eta;
231231
Fp = (eam.eta[type] / rho) * (F - eam.Fe[type] * x_eta);
232232
}
233233
}
234234

235235
// pair function (phi and phip have been intentionally halved here)
236-
static __device__ void find_phi(const EAM2006Dai& fs, const double d12, double& phi, double& phip)
236+
static __device__ void find_phi(const EAM2006Dai& fs, const float d12, float& phi, float& phip)
237237
{
238238
if (d12 > fs.c) {
239-
phi = 0.0;
240-
phip = 0.0;
239+
phi = 0.0f;
240+
phip = 0.0f;
241241
} else {
242-
double tmp = ((((fs.c4 * d12 + fs.c3) * d12 + fs.c2) * d12 + fs.c1) * d12 + fs.c0);
242+
float tmp = ((((fs.c4 * d12 + fs.c3) * d12 + fs.c2) * d12 + fs.c1) * d12 + fs.c0);
243243

244-
phi = 0.5 * (d12 - fs.c) * (d12 - fs.c) * tmp;
244+
phi = 0.5f * (d12 - fs.c) * (d12 - fs.c) * tmp;
245245

246-
phip = 2.0 * (d12 - fs.c) * tmp;
247-
phip += (((4.0 * fs.c4 * d12 + 3.0 * fs.c3) * d12 + 2.0 * fs.c2) * d12 + fs.c1) * (d12 - fs.c) *
246+
phip = 2.0f * (d12 - fs.c) * tmp;
247+
phip += (((4.0f * fs.c4 * d12 + 3.0f * fs.c3) * d12 + 2.0f * fs.c2) * d12 + fs.c1) * (d12 - fs.c) *
248248
(d12 - fs.c);
249-
phip *= 0.5;
249+
phip *= 0.5f;
250250
}
251251
}
252252

253253
// density function f(r)
254-
static __device__ void find_f(const EAM2006Dai& fs, const double d12, double& f)
254+
static __device__ void find_f(const EAM2006Dai& fs, const float d12, float& f)
255255
{
256256
if (d12 > fs.d) {
257-
f = 0.0;
257+
f = 0.0f;
258258
} else {
259-
double tmp = (d12 - fs.d) * (d12 - fs.d);
259+
float tmp = (d12 - fs.d) * (d12 - fs.d);
260260
f = tmp + fs.B * fs.B * tmp * tmp;
261261
}
262262
}
263263

264264
// derivative of the density function f'(r)
265-
static __device__ void find_fp(const EAM2006Dai& fs, const double d12, double& fp)
265+
static __device__ void find_fp(const EAM2006Dai& fs, const float d12, float& fp)
266266
{
267267
if (d12 > fs.d) {
268-
fp = 0.0;
268+
fp = 0.0f;
269269
} else {
270-
double tmp = 2.0 * (d12 - fs.d);
271-
fp = tmp * (1.0 + fs.B * fs.B * tmp * (d12 - fs.d));
270+
float tmp = 2.0f * (d12 - fs.d);
271+
fp = tmp * (1.0f + fs.B * fs.B * tmp * (d12 - fs.d));
272272
}
273273
}
274274

275275
// embedding function
276-
static __device__ void find_F(const EAM2006Dai& fs, const double rho, double& F, double& Fp)
276+
static __device__ void find_F(const EAM2006Dai& fs, const float rho, float& F, float& Fp)
277277
{
278-
double sqrt_rho = sqrt(rho);
278+
float sqrt_rho = sqrt(rho);
279279
F = -fs.A * sqrt_rho;
280-
Fp = -fs.A * 0.5 / sqrt_rho;
280+
Fp = -fs.A * 0.5f / sqrt_rho;
281281
}
282282

283283
// Calculate the embedding energy and its derivative
@@ -295,7 +295,7 @@ static __global__ void find_force_eam_step1(
295295
const double* __restrict__ g_x,
296296
const double* __restrict__ g_y,
297297
const double* __restrict__ g_z,
298-
double* g_Fp,
298+
float* g_Fp,
299299
double* g_pe)
300300
{
301301
int n1 = blockIdx.x * blockDim.x + threadIdx.x + N1; // particle index
@@ -308,15 +308,15 @@ static __global__ void find_force_eam_step1(
308308
double z1 = g_z[n1];
309309

310310
// Calculate the density
311-
double rho = 0.0;
311+
float rho = 0.0f;
312312
for (int i1 = 0; i1 < NN; ++i1) {
313313
int n2 = g_NL[n1 + N * i1];
314-
double x12 = g_x[n2] - x1;
315-
double y12 = g_y[n2] - y1;
316-
double z12 = g_z[n2] - z1;
314+
float x12 = g_x[n2] - x1;
315+
float y12 = g_y[n2] - y1;
316+
float z12 = g_z[n2] - z1;
317317
apply_mic(box, x12, y12, z12);
318-
double d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
319-
double rho12 = 0.0;
318+
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
319+
float rho12 = 0.0f;
320320
if (potential_model == 0) {
321321
find_f(eam2004zhou, g_type[n2], d12, rho12); // density is contributed by n2
322322
}
@@ -327,7 +327,7 @@ static __global__ void find_force_eam_step1(
327327
}
328328

329329
// Calculate the embedding energy F and its derivative Fp
330-
double F, Fp;
330+
float F, Fp;
331331
if (potential_model == 0)
332332
find_F(eam2004zhou, g_type[n1], rho, F, Fp); // embedding energy is for n1
333333
if (potential_model == 1)
@@ -350,7 +350,7 @@ static __global__ void find_force_eam_step2(
350350
const int* g_NN,
351351
const int* g_NL,
352352
const int* g_type,
353-
const double* __restrict__ g_Fp,
353+
const float* __restrict__ g_Fp,
354354
const double* __restrict__ g_x,
355355
const double* __restrict__ g_y,
356356
const double* __restrict__ g_z,
@@ -361,39 +361,39 @@ static __global__ void find_force_eam_step2(
361361
double* g_pe)
362362
{
363363
int n1 = blockIdx.x * blockDim.x + threadIdx.x + N1;
364-
double s_fx = 0.0; // force_x
365-
double s_fy = 0.0; // force_y
366-
double s_fz = 0.0; // force_z
367-
double s_pe = 0.0; // potential energy
368-
double s_sxx = 0.0; // virial_stress_xx
369-
double s_sxy = 0.0; // virial_stress_xy
370-
double s_sxz = 0.0; // virial_stress_xz
371-
double s_syx = 0.0; // virial_stress_yx
372-
double s_syy = 0.0; // virial_stress_yy
373-
double s_syz = 0.0; // virial_stress_yz
374-
double s_szx = 0.0; // virial_stress_zx
375-
double s_szy = 0.0; // virial_stress_zy
376-
double s_szz = 0.0; // virial_stress_zz
364+
float s_fx = 0.0f; // force_x
365+
float s_fy = 0.0f; // force_y
366+
float s_fz = 0.0f; // force_z
367+
float s_pe = 0.0f; // potential energy
368+
float s_sxx = 0.0f; // virial_stress_xx
369+
float s_sxy = 0.0f; // virial_stress_xy
370+
float s_sxz = 0.0f; // virial_stress_xz
371+
float s_syx = 0.0f; // virial_stress_yx
372+
float s_syy = 0.0f; // virial_stress_yy
373+
float s_syz = 0.0f; // virial_stress_yz
374+
float s_szx = 0.0f; // virial_stress_zx
375+
float s_szy = 0.0f; // virial_stress_zy
376+
float s_szz = 0.0f; // virial_stress_zz
377377

378378
if (n1 < N2) {
379379
int type1 = g_type[n1];
380380
int NN = g_NN[n1];
381381
double x1 = g_x[n1];
382382
double y1 = g_y[n1];
383383
double z1 = g_z[n1];
384-
double Fp1 = g_Fp[n1];
384+
float Fp1 = g_Fp[n1];
385385

386386
for (int i1 = 0; i1 < NN; ++i1) {
387387
int n2 = g_NL[n1 + N * i1];
388388
int type2 = g_type[n2];
389-
double Fp2 = g_Fp[n2];
390-
double x12 = g_x[n2] - x1;
391-
double y12 = g_y[n2] - y1;
392-
double z12 = g_z[n2] - z1;
389+
float Fp2 = g_Fp[n2];
390+
float x12 = g_x[n2] - x1;
391+
float y12 = g_y[n2] - y1;
392+
float z12 = g_z[n2] - z1;
393393
apply_mic(box, x12, y12, z12);
394-
double d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
394+
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
395395

396-
double phi, phip, fp1, fp2;
396+
float phi, phip, fp1, fp2;
397397
if (potential_model == 0) {
398398
find_phi(eam2004zhou, type1, type2, d12, phi, phip);
399399
if (type1 == type2) {
@@ -410,16 +410,16 @@ static __global__ void find_force_eam_step2(
410410
fp2 = fp1;
411411
}
412412

413-
double d12inv = 1.0 / d12;
413+
float d12inv = 1.0f / d12;
414414
phip *= d12inv;
415415
fp1 *= d12inv;
416416
fp2 *= d12inv;
417-
double f12x = x12 * (phip + Fp1 * fp2);
418-
double f12y = y12 * (phip + Fp1 * fp2);
419-
double f12z = z12 * (phip + Fp1 * fp2);
420-
double f21x = -x12 * (phip + Fp2 * fp1);
421-
double f21y = -y12 * (phip + Fp2 * fp1);
422-
double f21z = -z12 * (phip + Fp2 * fp1);
417+
float f12x = x12 * (phip + Fp1 * fp2);
418+
float f12y = y12 * (phip + Fp1 * fp2);
419+
float f12z = z12 * (phip + Fp1 * fp2);
420+
float f21x = -x12 * (phip + Fp2 * fp1);
421+
float f21y = -y12 * (phip + Fp2 * fp1);
422+
float f21z = -z12 * (phip + Fp2 * fp1);
423423

424424
// two-body potential energy
425425
s_pe += phi;

0 commit comments

Comments
 (0)