Skip to content

Commit 4ebefc7

Browse files
authored
Merge pull request brucefan1983#1169 from brucefan1983/improve_charge_train
Improve charge train
2 parents 242feef + 42ea205 commit 4ebefc7

4 files changed

Lines changed: 18 additions & 41 deletions

File tree

src/main_nep/parameters.cu

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,16 @@ void Parameters::calculate_parameters()
214214
}
215215

216216
if (version == 3) {
217+
number_of_variables_ann_1 = (dim + 2) * num_neurons1;
217218
number_of_variables_ann = (dim + 2) * num_neurons1 + 1;
218219
} else if (version == 4) {
220+
number_of_variables_ann_1 = (dim + 2) * num_neurons1;
219221
number_of_variables_ann = (dim + 2) * num_neurons1 * num_types + 1;
220222
if (charge_mode) {
223+
number_of_variables_ann_1 += num_neurons1;
221224
number_of_variables_ann += num_neurons1 * num_types + 1;
222225
if (charge_mode == 4) {
226+
number_of_variables_ann_1 += num_neurons1;
223227
number_of_variables_ann += num_neurons1 * num_types;
224228
}
225229
}
@@ -258,14 +262,7 @@ void Parameters::calculate_parameters()
258262
}
259263
std::vector<std::string> tokens;
260264
const int NUM89 = 89;
261-
int num_outputs = 1;
262-
if (charge_mode >= 1 && charge_mode <= 3) {
263-
num_outputs = 2;
264-
} else if (charge_mode == 4) {
265-
num_outputs = 3;
266-
}
267-
const int num_ann_per_element = (dim + 1 + num_outputs) * num_neurons1;
268-
const int num_ann = NUM89 * num_ann_per_element + (charge_mode ? 2 : 1);
265+
const int num_ann = NUM89 * number_of_variables_ann_1 + (charge_mode ? 2 : 1);
269266
const int num_cnk_radial = NUM89 * NUM89 * (n_max_radial + 1) * (basis_size_radial + 1);
270267
const int num_cnk_angular = NUM89 * NUM89 * (n_max_angular + 1) * (basis_size_angular + 1);
271268
const int num_tot = num_ann + num_cnk_radial + num_cnk_angular;

src/main_nep/parameters.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public:
106106
int dim_angular; // number of angular descriptor components
107107
int number_of_variables; // total number of parameters (NN and descriptor)
108108
int number_of_variables_ann; // number of parameters in the ANN only
109+
int number_of_variables_ann_1; // number of parameters in the ANN for one element
109110
int number_of_variables_descriptor; // number of parameters in the descriptor only
110111

111112
// some arrays

src/main_nep/snes.cu

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,13 @@ void SNES::initialize_mu_and_sigma(Parameters& para)
108108
}
109109
// make sure the initial charges are zero
110110
if (para.charge_mode) {
111-
int num_outputs = 1;
112-
if (para.charge_mode >= 1 && para.charge_mode <= 3) {
113-
num_outputs = 2;
114-
} else if (para.charge_mode == 4) {
115-
num_outputs = 3;
116-
}
117-
const int num_full = (para.dim + 1 + num_outputs) * para.num_neurons1;
118111
const int num_part = (para.dim + 2) * para.num_neurons1;
119112
for (int t = 0; t < para.num_types; ++t) {
120-
for (int n = num_full * t + num_part; n < num_full * (t + 1); ++n) {
113+
for (int n = para.number_of_variables_ann_1 * t + num_part; n < para.number_of_variables_ann_1 * (t + 1); ++n) {
121114
mu[n] = 0.0f;
122115
}
123116
}
124-
mu[num_full * para.num_types] = 2.0f; // make sure initial sqrt(epsilon_inf) > 0
117+
mu[para.number_of_variables_ann_1 * para.num_types] = 2.0f; // make sure initial sqrt(epsilon_inf) > 0
125118
}
126119
} else {
127120
for (int n = 0; n < number_of_variables; ++n) {
@@ -147,14 +140,7 @@ void SNES::initialize_mu_and_sigma_fine_tune(Parameters& para)
147140
};
148141
// read in the whole foundation file first
149142
const int NUM89 = 89;
150-
int num_outputs = 1;
151-
if (para.charge_mode >= 1 && para.charge_mode <= 3) {
152-
num_outputs = 2;
153-
} else if (para.charge_mode == 4) {
154-
num_outputs = 3;
155-
}
156-
const int num_ann_per_element = (para.dim + 1 + num_outputs) * para.num_neurons1;
157-
const int num_ann = NUM89 * num_ann_per_element + (para.charge_mode ? 2 : 1);
143+
const int num_ann = NUM89 * para.number_of_variables_ann_1 + (para.charge_mode ? 2 : 1);
158144
const int num_cnk_radial = NUM89 * NUM89 * (para.n_max_radial + 1) * (para.basis_size_radial + 1);
159145
const int num_cnk_angular = NUM89 * NUM89 * (para.n_max_angular + 1) * (para.basis_size_angular + 1);
160146
const int num_tot = num_ann + num_cnk_radial + num_cnk_angular;
@@ -182,9 +168,9 @@ void SNES::initialize_mu_and_sigma_fine_tune(Parameters& para)
182168
int count = 0;
183169
for (int i = 0; i < para.num_types; ++ i) {
184170
int element_index = element_map[para.atomic_numbers[i] - 1];
185-
for (int j = 0; j < num_ann_per_element; ++j) {
186-
mu[count] = restart_mu[element_index * num_ann_per_element + j];
187-
sigma[count] = restart_sigma[element_index * num_ann_per_element + j];
171+
for (int j = 0; j < para.number_of_variables_ann_1; ++j) {
172+
mu[count] = restart_mu[element_index * para.number_of_variables_ann_1 + j];
173+
sigma[count] = restart_sigma[element_index * para.number_of_variables_ann_1 + j];
188174
++count;
189175
}
190176
}
@@ -252,30 +238,22 @@ void SNES::calculate_utility()
252238

253239
void SNES::find_type_of_variable(Parameters& para)
254240
{
255-
int num_outputs = 1;
256-
if (para.charge_mode >= 1 && para.charge_mode <= 3) {
257-
num_outputs = 2;
258-
} else if (para.charge_mode == 4) {
259-
num_outputs = 3;
260-
}
261-
int num_para_ann_per_type = (para.dim + 1 + num_outputs) * para.num_neurons1;
262-
263241
int offset = 0;
264242

265243
// NN part
266244
if (para.version != 3) {
267245
int num_ann = (para.train_mode == 2) ? 2 : 1;
268246
for (int ann = 0; ann < num_ann; ++ann) {
269247
for (int t = 0; t < para.num_types; ++t) {
270-
for (int n = 0; n < num_para_ann_per_type; ++n) {
248+
for (int n = 0; n < para.number_of_variables_ann_1; ++n) {
271249
type_of_variable[n + offset] = t;
272250
}
273-
offset += num_para_ann_per_type;
251+
offset += para.number_of_variables_ann_1;
274252
}
275253
offset += para.charge_mode ? 2 : 1; // the bias
276254
}
277255
} else {
278-
offset += num_para_ann_per_type + 1;
256+
offset += para.number_of_variables_ann_1 + 1;
279257
}
280258

281259
// descriptor part
@@ -473,7 +451,8 @@ static __global__ void gpu_find_L1_L2_NEP4(
473451
s_cost_L2reg[tid] = 0.0f;
474452
for (int v = tid; v < number_of_variables; v += blockDim.x) {
475453
const float para = g_population[bid * number_of_variables + v];
476-
if (g_type_of_variable[v] == g_type || g_type == g_num_types) {
454+
if ((g_type_of_variable[v] == g_type) && (g_type != g_num_types) ||
455+
(g_type_of_variable[v] != g_type) && (g_type == g_num_types)) {
477456
s_cost_L1reg[tid] += abs(para);
478457
s_cost_L2reg[tid] += para * para;
479458
}

src/makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ HEADERS = \
8383
###########################################################
8484
# executables
8585
###########################################################
86-
all: gpumd nep gnep
86+
all: gpumd nep
8787
gpumd: $(OBJ_GPUMD)
8888
$(CC) $(LDFLAGS) $^ -o $@ $(LIBS)
8989
@echo =================================================

0 commit comments

Comments
 (0)