Skip to content

Commit e644c9b

Browse files
committed
Use cfg_select! for code based on target_feature
Ups the Rust requirement to 1.95, so may be a while before this could be merged. But is a _much_ nicer way of expressing this pattern.
1 parent b2ad933 commit e644c9b

7 files changed

Lines changed: 269 additions & 229 deletions

File tree

src/nnue.rs

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,38 @@ use crate::{
1212
};
1313

1414
mod forward {
15-
#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
16-
mod vectorized;
17-
#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
18-
pub use vectorized::*;
19-
20-
#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
21-
mod scalar;
22-
#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
23-
pub use scalar::*;
15+
cfg_select! {
16+
any(target_feature = "avx2", target_feature = "neon") => {
17+
mod vectorized;
18+
pub use vectorized::*;
19+
}
20+
_ => {
21+
mod scalar;
22+
pub use scalar::*;
23+
}
24+
}
2425
}
2526

2627
mod simd {
27-
#[cfg(target_feature = "avx512f")]
28-
mod avx512;
29-
#[cfg(target_feature = "avx512f")]
30-
pub use avx512::*;
31-
32-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
33-
mod avx2;
34-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
35-
pub use avx2::*;
36-
37-
#[cfg(all(target_feature = "neon", not(any(target_feature = "avx2", target_feature = "avx512f"))))]
38-
mod neon;
39-
#[cfg(all(target_feature = "neon", not(any(target_feature = "avx2", target_feature = "avx512f"))))]
40-
pub use neon::*;
41-
42-
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2", target_feature = "neon")))]
43-
mod scalar;
44-
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2", target_feature = "neon")))]
45-
pub use scalar::*;
28+
cfg_select! {
29+
target_feature = "avx512f" => {
30+
mod avx512;
31+
pub use avx512::*;
32+
}
33+
target_feature = "avx2" => {
34+
mod avx2;
35+
pub use avx2::*;
36+
}
37+
target_feature = "neon" => {
38+
mod neon;
39+
pub use neon::*;
40+
}
41+
_ => {
42+
mod scalar;
43+
pub use scalar::*;
44+
}
45+
}
46+
4647
}
4748

4849
const NETWORK_SCALE: i32 = 380;
@@ -57,10 +58,14 @@ const L3_SIZE: usize = 32;
5758
const FT_QUANT: i32 = 255;
5859
const L1_QUANT: i32 = 64;
5960

60-
#[cfg(target_feature = "avx512f")]
61-
const FT_SHIFT: u32 = 9;
62-
#[cfg(not(target_feature = "avx512f"))]
63-
const FT_SHIFT: i32 = 9;
61+
cfg_select! {
62+
target_feature = "avx512f" => {
63+
const FT_SHIFT: u32 = 9;
64+
}
65+
_ => {
66+
const FT_SHIFT: i32 = 9;
67+
}
68+
}
6469

6570
const DEQUANT_MULTIPLIER: f32 = (1 << FT_SHIFT) as f32 / (FT_QUANT * FT_QUANT * L1_QUANT) as f32;
6671

src/nnue/accumulator/psq.rs

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -66,39 +66,42 @@ impl PstAccumulator {
6666
self.accurate[pov] = true;
6767
}
6868

69-
#[inline]
70-
#[cfg(not(target_feature = "avx512vbmi2"))]
71-
fn push_features(
72-
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
73-
pov: Color,
74-
) {
75-
for square in bb {
76-
features.push(pst_index(color, piece_type, square, king, pov));
69+
cfg_select! {
70+
target_feature = "avx512vbmi2" => {
71+
#[inline]
72+
fn push_features(
73+
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
74+
pov: Color,
75+
) {
76+
unsafe {
77+
use std::arch::x86_64::*;
78+
79+
let base = pst_index(color, piece_type, Square::new(0), king, pov);
80+
81+
let iota = _mm512_set_epi8(
82+
63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38,
83+
37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
84+
11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
85+
);
86+
let squares = _mm512_castsi512_si128(_mm512_maskz_compress_epi8(bb.0, iota));
87+
let to_write = _mm256_xor_si256(_mm256_set1_epi16(base as i16), _mm256_cvtepu8_epi16(squares));
88+
features.unchecked_write(|data| {
89+
_mm256_storeu_si256(data.cast(), to_write);
90+
bb.count()
91+
});
92+
}
93+
}
7794
}
78-
}
79-
80-
#[inline]
81-
#[cfg(target_feature = "avx512vbmi2")]
82-
fn push_features(
83-
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
84-
pov: Color,
85-
) {
86-
unsafe {
87-
use std::arch::x86_64::*;
88-
89-
let base = pst_index(color, piece_type, Square::new(0), king, pov);
90-
91-
let iota = _mm512_set_epi8(
92-
63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38,
93-
37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
94-
11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
95-
);
96-
let squares = _mm512_castsi512_si128(_mm512_maskz_compress_epi8(bb.0, iota));
97-
let to_write = _mm256_xor_si256(_mm256_set1_epi16(base as i16), _mm256_cvtepu8_epi16(squares));
98-
features.unchecked_write(|data| {
99-
_mm256_storeu_si256(data.cast(), to_write);
100-
bb.count()
101-
});
95+
_ => {
96+
#[inline]
97+
fn push_features(
98+
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
99+
pov: Color,
100+
) {
101+
for square in bb {
102+
features.push(pst_index(color, piece_type, square, king, pov));
103+
}
104+
}
102105
}
103106
}
104107

src/nnue/accumulator/threats.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ use crate::{
88
mod threat_index;
99
pub use threat_index::*;
1010

11-
#[cfg(not(target_feature = "avx2"))]
12-
mod scalar;
13-
#[cfg(not(target_feature = "avx2"))]
14-
pub use scalar::*;
15-
#[cfg(target_feature = "avx2")]
16-
mod vectorized;
17-
#[cfg(target_feature = "avx2")]
18-
pub use vectorized::*;
11+
cfg_select! {
12+
target_feature = "avx2" => {
13+
mod vectorized;
14+
pub use vectorized::*;
15+
}
16+
_ => {
17+
mod scalar;
18+
pub use scalar::*;
19+
}
20+
}
1921

2022
#[derive(Copy, Clone)]
2123
#[repr(transparent)]
@@ -88,10 +90,10 @@ impl ThreatAccumulator {
8890
}
8991
}
9092

91-
#[cfg(target_feature = "avx512f")]
92-
const REGISTERS: usize = L1_SIZE / simd::I16_LANES;
93-
#[cfg(not(target_feature = "avx512f"))]
94-
const REGISTERS: usize = 8;
93+
const REGISTERS: usize = cfg_select! {
94+
target_feature = "avx512f" => L1_SIZE / simd::I16_LANES,
95+
_ => 8
96+
};
9597

9698
unsafe {
9799
for offset in (0..L1_SIZE).step_by(REGISTERS * simd::I16_LANES) {
@@ -153,10 +155,10 @@ impl ThreatAccumulator {
153155
}
154156
}
155157

156-
#[cfg(target_feature = "avx512f")]
157-
const REGISTERS: usize = L1_SIZE / simd::I16_LANES;
158-
#[cfg(not(target_feature = "avx512f"))]
159-
const REGISTERS: usize = 8;
158+
const REGISTERS: usize = cfg_select! {
159+
target_feature = "avx512f" => L1_SIZE / simd::I16_LANES,
160+
_ => 8
161+
};
160162

161163
let mut registers: [_; REGISTERS] = std::mem::zeroed();
162164

src/nnue/accumulator/threats/vectorized.rs

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ use crate::{
44
types::{Piece, Square},
55
};
66

7-
#[cfg(target_feature = "avx512vbmi2")]
8-
mod avx512;
9-
#[cfg(target_feature = "avx512vbmi2")]
10-
use avx512::*;
11-
12-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
13-
mod avx2;
14-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
15-
use avx2::*;
7+
cfg_select! {
8+
target_feature = "avx512vbmi2" => {
9+
mod avx512;
10+
use avx512::*;
11+
}
12+
target_feature = "avx2" => {
13+
mod avx2;
14+
use avx2::*;
15+
}
16+
}
1617

1718
const RAY_PERMUTATIONS: [[u8; 64]; 64] = {
1819
const OFFSETS: [u8; 64] = [
@@ -107,10 +108,14 @@ const RAY_SLIDERS_MASK: [u8; 64] = {
107108
};
108109

109110
pub fn push_threats_on_change(accum: &mut ThreatAccumulator, board: &Board, piece: Piece, square: Square, add: bool) {
110-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
111-
let board = unsafe { board.mailbox_vector_avx2() };
112-
#[cfg(target_feature = "avx512vbmi2")]
113-
let board = unsafe { board.mailbox_vector_avx512() };
111+
let board = unsafe { cfg_select! {
112+
target_feature = "avx512vbmi2" => {
113+
board.mailbox_vector_avx512()
114+
}
115+
target_feature = "avx2" => {
116+
board.mailbox_avx2()
117+
}
118+
} };
114119

115120
let (perm, valid) = ray_permutation(square);
116121
let (pboard, rays) = board_to_rays(perm, valid, board);
@@ -128,10 +133,14 @@ pub fn push_threats_on_change(accum: &mut ThreatAccumulator, board: &Board, piec
128133
}
129134

130135
pub fn push_threats_on_move(accum: &mut ThreatAccumulator, board: &Board, piece: Piece, src: Square, dst: Square) {
131-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
132-
let board = unsafe { board.mailbox_vector_avx2() };
133-
#[cfg(target_feature = "avx512vbmi2")]
134-
let board = unsafe { board.mailbox_vector_avx512() };
136+
let board = unsafe { cfg_select! {
137+
target_feature = "avx512vbmi2" => {
138+
board.mailbox_vector_avx512()
139+
}
140+
target_feature = "avx2" => {
141+
board.mailbox_avx2()
142+
}
143+
} };
135144

136145
let (src_perm, src_valid) = ray_permutation(src);
137146
let (dst_perm, dst_valid) = ray_permutation(dst);
@@ -173,10 +182,14 @@ pub fn push_threats_on_move(accum: &mut ThreatAccumulator, board: &Board, piece:
173182
pub fn push_threats_on_mutate(
174183
accum: &mut ThreatAccumulator, board: &Board, old_piece: Piece, new_piece: Piece, square: Square,
175184
) {
176-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
177-
let board = unsafe { board.mailbox_vector_avx2() };
178-
#[cfg(target_feature = "avx512vbmi2")]
179-
let board = unsafe { board.mailbox_vector_avx512() };
185+
let board = unsafe { cfg_select! {
186+
target_feature = "avx512vbmi2" => {
187+
board.mailbox_vector_avx512()
188+
}
189+
target_feature = "avx2" => {
190+
board.mailbox_avx2()
191+
}
192+
} };
180193

181194
let (perm, valid) = ray_permutation(square);
182195
let (pboard, rays) = board_to_rays(perm, valid, board);

0 commit comments

Comments
 (0)