Skip to content

Commit 36f6c24

Browse files
committed
Use cfg_select! for code based on target_feature
This raises the minimum Rust version to 1.95, so it may be a while before this can be merged, but it is a _much_ nicer way of expressing this pattern.
1 parent b2ad933 commit 36f6c24

7 files changed

Lines changed: 256 additions & 229 deletions

File tree

src/nnue.rs

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,37 @@ use crate::{
1212
};
1313

1414
mod forward {
15-
#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
16-
mod vectorized;
17-
#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
18-
pub use vectorized::*;
19-
20-
#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
21-
mod scalar;
22-
#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
23-
pub use scalar::*;
15+
cfg_select! {
16+
any(target_feature = "avx2", target_feature = "neon") => {
17+
mod vectorized;
18+
pub use vectorized::*;
19+
}
20+
_ => {
21+
mod scalar;
22+
pub use scalar::*;
23+
}
24+
}
2425
}
2526

2627
mod simd {
27-
#[cfg(target_feature = "avx512f")]
28-
mod avx512;
29-
#[cfg(target_feature = "avx512f")]
30-
pub use avx512::*;
31-
32-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
33-
mod avx2;
34-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
35-
pub use avx2::*;
36-
37-
#[cfg(all(target_feature = "neon", not(any(target_feature = "avx2", target_feature = "avx512f"))))]
38-
mod neon;
39-
#[cfg(all(target_feature = "neon", not(any(target_feature = "avx2", target_feature = "avx512f"))))]
40-
pub use neon::*;
41-
42-
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2", target_feature = "neon")))]
43-
mod scalar;
44-
#[cfg(not(any(target_feature = "avx512f", target_feature = "avx2", target_feature = "neon")))]
45-
pub use scalar::*;
28+
cfg_select! {
29+
target_feature = "avx512f" => {
30+
mod avx512;
31+
pub use avx512::*;
32+
}
33+
target_feature = "avx2" => {
34+
mod avx2;
35+
pub use avx2::*;
36+
}
37+
target_feature = "neon" => {
38+
mod neon;
39+
pub use neon::*;
40+
}
41+
_ => {
42+
mod scalar;
43+
pub use scalar::*;
44+
}
45+
}
4646
}
4747

4848
const NETWORK_SCALE: i32 = 380;
@@ -57,10 +57,14 @@ const L3_SIZE: usize = 32;
5757
const FT_QUANT: i32 = 255;
5858
const L1_QUANT: i32 = 64;
5959

60-
#[cfg(target_feature = "avx512f")]
61-
const FT_SHIFT: u32 = 9;
62-
#[cfg(not(target_feature = "avx512f"))]
63-
const FT_SHIFT: i32 = 9;
60+
cfg_select! {
61+
target_feature = "avx512f" => {
62+
const FT_SHIFT: u32 = 9;
63+
}
64+
_ => {
65+
const FT_SHIFT: i32 = 9;
66+
}
67+
}
6468

6569
const DEQUANT_MULTIPLIER: f32 = (1 << FT_SHIFT) as f32 / (FT_QUANT * FT_QUANT * L1_QUANT) as f32;
6670

src/nnue/accumulator/psq.rs

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -66,39 +66,42 @@ impl PstAccumulator {
6666
self.accurate[pov] = true;
6767
}
6868

69-
#[inline]
70-
#[cfg(not(target_feature = "avx512vbmi2"))]
71-
fn push_features(
72-
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
73-
pov: Color,
74-
) {
75-
for square in bb {
76-
features.push(pst_index(color, piece_type, square, king, pov));
69+
cfg_select! {
70+
target_feature = "avx512vbmi2" => {
71+
#[inline]
72+
fn push_features(
73+
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
74+
pov: Color,
75+
) {
76+
unsafe {
77+
use std::arch::x86_64::*;
78+
79+
let base = pst_index(color, piece_type, Square::new(0), king, pov);
80+
81+
let iota = _mm512_set_epi8(
82+
63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38,
83+
37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
84+
11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
85+
);
86+
let squares = _mm512_castsi512_si128(_mm512_maskz_compress_epi8(bb.0, iota));
87+
let to_write = _mm256_xor_si256(_mm256_set1_epi16(base as i16), _mm256_cvtepu8_epi16(squares));
88+
features.unchecked_write(|data| {
89+
_mm256_storeu_si256(data.cast(), to_write);
90+
bb.count()
91+
});
92+
}
93+
}
7794
}
78-
}
79-
80-
#[inline]
81-
#[cfg(target_feature = "avx512vbmi2")]
82-
fn push_features(
83-
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
84-
pov: Color,
85-
) {
86-
unsafe {
87-
use std::arch::x86_64::*;
88-
89-
let base = pst_index(color, piece_type, Square::new(0), king, pov);
90-
91-
let iota = _mm512_set_epi8(
92-
63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38,
93-
37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
94-
11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
95-
);
96-
let squares = _mm512_castsi512_si128(_mm512_maskz_compress_epi8(bb.0, iota));
97-
let to_write = _mm256_xor_si256(_mm256_set1_epi16(base as i16), _mm256_cvtepu8_epi16(squares));
98-
features.unchecked_write(|data| {
99-
_mm256_storeu_si256(data.cast(), to_write);
100-
bb.count()
101-
});
95+
_ => {
96+
#[inline]
97+
fn push_features(
98+
features: &mut ArrayVec<PstFeature, 64>, color: Color, piece_type: PieceType, bb: Bitboard, king: Square,
99+
pov: Color,
100+
) {
101+
for square in bb {
102+
features.push(pst_index(color, piece_type, square, king, pov));
103+
}
104+
}
102105
}
103106
}
104107

src/nnue/accumulator/threats.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ use crate::{
88
mod threat_index;
99
pub use threat_index::*;
1010

11-
#[cfg(not(target_feature = "avx2"))]
12-
mod scalar;
13-
#[cfg(not(target_feature = "avx2"))]
14-
pub use scalar::*;
15-
#[cfg(target_feature = "avx2")]
16-
mod vectorized;
17-
#[cfg(target_feature = "avx2")]
18-
pub use vectorized::*;
11+
cfg_select! {
12+
target_feature = "avx2" => {
13+
mod vectorized;
14+
pub use vectorized::*;
15+
}
16+
_ => {
17+
mod scalar;
18+
pub use scalar::*;
19+
}
20+
}
1921

2022
#[derive(Copy, Clone)]
2123
#[repr(transparent)]
@@ -88,10 +90,10 @@ impl ThreatAccumulator {
8890
}
8991
}
9092

91-
#[cfg(target_feature = "avx512f")]
92-
const REGISTERS: usize = L1_SIZE / simd::I16_LANES;
93-
#[cfg(not(target_feature = "avx512f"))]
94-
const REGISTERS: usize = 8;
93+
const REGISTERS: usize = cfg_select! {
94+
target_feature = "avx512f" => L1_SIZE / simd::I16_LANES,
95+
_ => 8
96+
};
9597

9698
unsafe {
9799
for offset in (0..L1_SIZE).step_by(REGISTERS * simd::I16_LANES) {
@@ -153,10 +155,10 @@ impl ThreatAccumulator {
153155
}
154156
}
155157

156-
#[cfg(target_feature = "avx512f")]
157-
const REGISTERS: usize = L1_SIZE / simd::I16_LANES;
158-
#[cfg(not(target_feature = "avx512f"))]
159-
const REGISTERS: usize = 8;
158+
const REGISTERS: usize = cfg_select! {
159+
target_feature = "avx512f" => L1_SIZE / simd::I16_LANES,
160+
_ => 8
161+
};
160162

161163
let mut registers: [_; REGISTERS] = std::mem::zeroed();
162164

src/nnue/accumulator/threats/vectorized.rs

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ use crate::{
44
types::{Piece, Square},
55
};
66

7-
#[cfg(target_feature = "avx512vbmi2")]
8-
mod avx512;
9-
#[cfg(target_feature = "avx512vbmi2")]
10-
use avx512::*;
11-
12-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
13-
mod avx2;
14-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
15-
use avx2::*;
7+
cfg_select! {
8+
target_feature = "avx512vbmi2" => {
9+
mod avx512;
10+
use avx512::*;
11+
}
12+
target_feature = "avx2" => {
13+
mod avx2;
14+
use avx2::*;
15+
}
16+
}
1617

1718
const RAY_PERMUTATIONS: [[u8; 64]; 64] = {
1819
const OFFSETS: [u8; 64] = [
@@ -107,10 +108,10 @@ const RAY_SLIDERS_MASK: [u8; 64] = {
107108
};
108109

109110
pub fn push_threats_on_change(accum: &mut ThreatAccumulator, board: &Board, piece: Piece, square: Square, add: bool) {
110-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
111-
let board = unsafe { board.mailbox_vector_avx2() };
112-
#[cfg(target_feature = "avx512vbmi2")]
113-
let board = unsafe { board.mailbox_vector_avx512() };
111+
let board = unsafe { cfg_select! {
112+
target_feature = "avx512vbmi2" => board.mailbox_vector_avx512()
113+
target_feature = "avx2" => board.mailbox_avx2()
114+
} };
114115

115116
let (perm, valid) = ray_permutation(square);
116117
let (pboard, rays) = board_to_rays(perm, valid, board);
@@ -128,10 +129,10 @@ pub fn push_threats_on_change(accum: &mut ThreatAccumulator, board: &Board, piec
128129
}
129130

130131
pub fn push_threats_on_move(accum: &mut ThreatAccumulator, board: &Board, piece: Piece, src: Square, dst: Square) {
131-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
132-
let board = unsafe { board.mailbox_vector_avx2() };
133-
#[cfg(target_feature = "avx512vbmi2")]
134-
let board = unsafe { board.mailbox_vector_avx512() };
132+
let board = unsafe { cfg_select! {
133+
target_feature = "avx512vbmi2" => board.mailbox_vector_avx512()
134+
target_feature = "avx2" => board.mailbox_avx2()
135+
} };
135136

136137
let (src_perm, src_valid) = ray_permutation(src);
137138
let (dst_perm, dst_valid) = ray_permutation(dst);
@@ -173,10 +174,10 @@ pub fn push_threats_on_move(accum: &mut ThreatAccumulator, board: &Board, piece:
173174
pub fn push_threats_on_mutate(
174175
accum: &mut ThreatAccumulator, board: &Board, old_piece: Piece, new_piece: Piece, square: Square,
175176
) {
176-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi2")))]
177-
let board = unsafe { board.mailbox_vector_avx2() };
178-
#[cfg(target_feature = "avx512vbmi2")]
179-
let board = unsafe { board.mailbox_vector_avx512() };
177+
let board = unsafe { cfg_select! {
178+
target_feature = "avx512vbmi2" => board.mailbox_vector_avx512()
179+
target_feature = "avx2" => board.mailbox_avx2()
180+
} };
180181

181182
let (perm, valid) = ray_permutation(square);
182183
let (pboard, rays) = board_to_rays(perm, valid, board);

0 commit comments

Comments
 (0)