@@ -108,20 +108,13 @@ void SNES::initialize_mu_and_sigma(Parameters& para)
108108 }
109109 // make sure the initial charges are zero
110110 if (para.charge_mode ) {
111- int num_outputs = 1 ;
112- if (para.charge_mode >= 1 && para.charge_mode <= 3 ) {
113- num_outputs = 2 ;
114- } else if (para.charge_mode == 4 ) {
115- num_outputs = 3 ;
116- }
117- const int num_full = (para.dim + 1 + num_outputs) * para.num_neurons1 ;
118111 const int num_part = (para.dim + 2 ) * para.num_neurons1 ;
119112 for (int t = 0 ; t < para.num_types ; ++t) {
120- for (int n = num_full * t + num_part; n < num_full * (t + 1 ); ++n) {
113+ for (int n = para. number_of_variables_ann_1 * t + num_part; n < para. number_of_variables_ann_1 * (t + 1 ); ++n) {
121114 mu[n] = 0 .0f ;
122115 }
123116 }
124- mu[num_full * para.num_types ] = 2 .0f ; // make sure initial sqrt(epsilon_inf) > 0
117+ mu[para. number_of_variables_ann_1 * para.num_types ] = 2 .0f ; // make sure initial sqrt(epsilon_inf) > 0
125118 }
126119 } else {
127120 for (int n = 0 ; n < number_of_variables; ++n) {
@@ -147,14 +140,7 @@ void SNES::initialize_mu_and_sigma_fine_tune(Parameters& para)
147140 };
148141 // read in the whole foundation file first
149142 const int NUM89 = 89 ;
150- int num_outputs = 1 ;
151- if (para.charge_mode >= 1 && para.charge_mode <= 3 ) {
152- num_outputs = 2 ;
153- } else if (para.charge_mode == 4 ) {
154- num_outputs = 3 ;
155- }
156- const int num_ann_per_element = (para.dim + 1 + num_outputs) * para.num_neurons1 ;
157- const int num_ann = NUM89 * num_ann_per_element + (para.charge_mode ? 2 : 1 );
143+ const int num_ann = NUM89 * para.number_of_variables_ann_1 + (para.charge_mode ? 2 : 1 );
158144 const int num_cnk_radial = NUM89 * NUM89 * (para.n_max_radial + 1 ) * (para.basis_size_radial + 1 );
159145 const int num_cnk_angular = NUM89 * NUM89 * (para.n_max_angular + 1 ) * (para.basis_size_angular + 1 );
160146 const int num_tot = num_ann + num_cnk_radial + num_cnk_angular;
@@ -182,9 +168,9 @@ void SNES::initialize_mu_and_sigma_fine_tune(Parameters& para)
182168 int count = 0 ;
183169 for (int i = 0 ; i < para.num_types ; ++ i) {
184170 int element_index = element_map[para.atomic_numbers [i] - 1 ];
185- for (int j = 0 ; j < num_ann_per_element ; ++j) {
186- mu[count] = restart_mu[element_index * num_ann_per_element + j];
187- sigma[count] = restart_sigma[element_index * num_ann_per_element + j];
171+ for (int j = 0 ; j < para. number_of_variables_ann_1 ; ++j) {
172+ mu[count] = restart_mu[element_index * para. number_of_variables_ann_1 + j];
173+ sigma[count] = restart_sigma[element_index * para. number_of_variables_ann_1 + j];
188174 ++count;
189175 }
190176 }
@@ -252,30 +238,22 @@ void SNES::calculate_utility()
252238
253239void SNES::find_type_of_variable (Parameters& para)
254240{
255- int num_outputs = 1 ;
256- if (para.charge_mode >= 1 && para.charge_mode <= 3 ) {
257- num_outputs = 2 ;
258- } else if (para.charge_mode == 4 ) {
259- num_outputs = 3 ;
260- }
261- int num_para_ann_per_type = (para.dim + 1 + num_outputs) * para.num_neurons1 ;
262-
263241 int offset = 0 ;
264242
265243 // NN part
266244 if (para.version != 3 ) {
267245 int num_ann = (para.train_mode == 2 ) ? 2 : 1 ;
268246 for (int ann = 0 ; ann < num_ann; ++ann) {
269247 for (int t = 0 ; t < para.num_types ; ++t) {
270- for (int n = 0 ; n < num_para_ann_per_type ; ++n) {
248+ for (int n = 0 ; n < para. number_of_variables_ann_1 ; ++n) {
271249 type_of_variable[n + offset] = t;
272250 }
273- offset += num_para_ann_per_type ;
251+ offset += para. number_of_variables_ann_1 ;
274252 }
275253 offset += para.charge_mode ? 2 : 1 ; // the bias
276254 }
277255 } else {
278- offset += num_para_ann_per_type + 1 ;
256+ offset += para. number_of_variables_ann_1 + 1 ;
279257 }
280258
281259 // descriptor part
@@ -473,7 +451,8 @@ static __global__ void gpu_find_L1_L2_NEP4(
473451 s_cost_L2reg[tid] = 0 .0f ;
474452 for (int v = tid; v < number_of_variables; v += blockDim .x ) {
475453 const float para = g_population[bid * number_of_variables + v];
476- if (g_type_of_variable[v] == g_type || g_type == g_num_types) {
454+ if ((g_type_of_variable[v] == g_type) && (g_type != g_num_types) ||
455+ (g_type_of_variable[v] != g_type) && (g_type == g_num_types)) {
477456 s_cost_L1reg[tid] += abs (para);
478457 s_cost_L2reg[tid] += para * para;
479458 }
0 commit comments