Skip to content

Commit cebc2ca

Browse files
committed
Add opaque masks
1 parent a69c441 commit cebc2ca

18 files changed

+379
-146
lines changed

crates/core_simd/src/fmt.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ macro_rules! impl_fmt_trait {
7474

7575
impl_fmt_trait! {
7676
integers:
77-
crate::u8x8, crate::u8x16, crate::u8x32, crate::u8x64,
78-
crate::i8x8, crate::i8x16, crate::i8x32, crate::i8x64,
79-
crate::u16x4, crate::u16x8, crate::u16x16, crate::u16x32,
80-
crate::i16x4, crate::i16x8, crate::i16x16, crate::i16x32,
77+
crate::u8x8, crate::u8x16, crate::u8x32, crate::u8x64,
78+
crate::i8x8, crate::i8x16, crate::i8x32, crate::i8x64,
79+
crate::u16x4, crate::u16x8, crate::u16x16, crate::u16x32,
80+
crate::i16x4, crate::i16x8, crate::i16x16, crate::i16x32,
8181
crate::u32x2, crate::u32x4, crate::u32x8, crate::u32x16,
8282
crate::i32x2, crate::i32x4, crate::i32x8, crate::i32x16,
8383
crate::u64x2, crate::u64x4, crate::u64x8,
@@ -96,10 +96,10 @@ impl_fmt_trait! {
9696

9797
impl_fmt_trait! {
9898
masks:
99-
crate::mask8x8, crate::mask8x16, crate::mask8x32, crate::mask8x64,
100-
crate::mask16x4, crate::mask16x8, crate::mask16x16, crate::mask16x32,
101-
crate::mask32x2, crate::mask32x4, crate::mask32x8, crate::mask32x16,
102-
crate::mask64x2, crate::mask64x4, crate::mask64x8,
103-
crate::mask128x2, crate::mask128x4,
104-
crate::masksizex2, crate::masksizex4, crate::masksizex8,
99+
crate::masks::wide::m8x8, crate::masks::wide::m8x16, crate::masks::wide::m8x32, crate::masks::wide::m8x64,
100+
crate::masks::wide::m16x4, crate::masks::wide::m16x8, crate::masks::wide::m16x16, crate::masks::wide::m16x32,
101+
crate::masks::wide::m32x2, crate::masks::wide::m32x4, crate::masks::wide::m32x8, crate::masks::wide::m32x16,
102+
crate::masks::wide::m64x2, crate::masks::wide::m64x4, crate::masks::wide::m64x8,
103+
crate::masks::wide::m128x2, crate::masks::wide::m128x4,
104+
crate::masks::wide::msizex2, crate::masks::wide::msizex4, crate::masks::wide::msizex8,
105105
}

crates/core_simd/src/lib.rs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ mod fmt;
1010
mod intrinsics;
1111
mod ops;
1212

13-
mod masks;
14-
pub use masks::*;
13+
pub mod masks;
1514

1615
mod vectors_u8;
1716
pub use vectors_u8::*;
@@ -44,17 +43,4 @@ pub use vectors_f32::*;
4443
mod vectors_f64;
4544
pub use vectors_f64::*;
4645

47-
mod vectors_mask8;
48-
pub use vectors_mask8::*;
49-
mod vectors_mask16;
50-
pub use vectors_mask16::*;
51-
mod vectors_mask32;
52-
pub use vectors_mask32::*;
53-
mod vectors_mask64;
54-
pub use vectors_mask64::*;
55-
mod vectors_mask128;
56-
pub use vectors_mask128::*;
57-
mod vectors_masksize;
58-
pub use vectors_masksize::*;
59-
6046
mod round;

crates/core_simd/src/macros.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ macro_rules! define_float_vector {
314314
}
315315
}
316316

317-
318317
/// Defines an integer vector `$name` containing multiple `$lanes` of integer `$type`.
319318
macro_rules! define_integer_vector {
320319
{ $(#[$attr:meta])* struct $name:ident([$type:ty; $lanes:tt]); } => {
@@ -336,6 +335,7 @@ macro_rules! define_mask_vector {
336335
impl $name {
337336
call_repeat! { $lanes => define_mask_vector [$impl_type] splat $type | }
338337
call_counting_args! { $lanes => define_mask_vector => new $type | }
338+
call_counting_args! { $lanes => define_mask_vector => new_from_bool $type | }
339339
}
340340

341341
base_vector_traits! { $name => [$type; $lanes] }
@@ -361,5 +361,14 @@ macro_rules! define_mask_vector {
361361
pub const fn new($($var: $type),*) -> Self {
362362
Self($($var.0),*)
363363
}
364+
};
365+
{ new_from_bool $type:ty | $($var:ident)* } => {
366+
/// Used internally (since we can't use the Into trait in `const fn`s)
367+
#[allow(clippy::too_many_arguments)]
368+
#[allow(unused)]
369+
#[inline]
370+
pub(crate) const fn new_from_bool($($var: bool),*) -> Self {
371+
Self($(<$type>::new($var).0),*)
372+
}
364373
}
365374
}

crates/core_simd/src/masks/mod.rs

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
//! Types and traits associated with masking lanes of vectors.
2+
3+
pub mod wide;
4+
5+
trait MaskImpl {
6+
type Mask;
7+
}
8+
9+
impl MaskImpl for [u8; 8] {
10+
type Mask = wide::m8x8;
11+
}
12+
13+
impl MaskImpl for [u8; 16] {
14+
type Mask = wide::m8x16;
15+
}
16+
17+
impl MaskImpl for [u8; 32] {
18+
type Mask = wide::m8x32;
19+
}
20+
21+
impl MaskImpl for [u8; 64] {
22+
type Mask = wide::m8x64;
23+
}
24+
25+
impl MaskImpl for [u16; 4] {
26+
type Mask = wide::m16x4;
27+
}
28+
29+
impl MaskImpl for [u16; 8] {
30+
type Mask = wide::m16x8;
31+
}
32+
33+
impl MaskImpl for [u16; 16] {
34+
type Mask = wide::m16x16;
35+
}
36+
37+
impl MaskImpl for [u16; 32] {
38+
type Mask = wide::m16x32;
39+
}
40+
41+
impl MaskImpl for [u32; 2] {
42+
type Mask = wide::m32x2;
43+
}
44+
45+
impl MaskImpl for [u32; 4] {
46+
type Mask = wide::m32x4;
47+
}
48+
49+
impl MaskImpl for [u32; 8] {
50+
type Mask = wide::m32x8;
51+
}
52+
53+
impl MaskImpl for [u32; 16] {
54+
type Mask = wide::m32x16;
55+
}
56+
57+
impl MaskImpl for [u64; 2] {
58+
type Mask = wide::m64x2;
59+
}
60+
61+
impl MaskImpl for [u64; 4] {
62+
type Mask = wide::m64x4;
63+
}
64+
65+
impl MaskImpl for [u64; 8] {
66+
type Mask = wide::m64x8;
67+
}
68+
69+
impl MaskImpl for [u128; 2] {
70+
type Mask = wide::m128x2;
71+
}
72+
73+
impl MaskImpl for [u128; 4] {
74+
type Mask = wide::m128x4;
75+
}
76+
77+
impl MaskImpl for [usize; 2] {
78+
type Mask = wide::msizex2;
79+
}
80+
81+
impl MaskImpl for [usize; 4] {
82+
type Mask = wide::msizex4;
83+
}
84+
85+
impl MaskImpl for [usize; 8] {
86+
type Mask = wide::msizex8;
87+
}
88+
89+
macro_rules! define_opaque_mask {
90+
{
91+
$(#[$attr:meta])*
92+
struct $name:ident([$width:ty; $lanes:tt]);
93+
} => {
94+
$(#[$attr])*
95+
#[allow(non_camel_case_types)]
96+
pub struct $name(<[$width; $lanes] as MaskImpl>::Mask);
97+
98+
impl $name {
99+
/// Construct a mask by setting all lanes to the given value.
100+
pub fn splat(value: bool) -> Self {
101+
Self(<[$width; $lanes] as MaskImpl>::Mask::splat(value.into()))
102+
}
103+
104+
call_counting_args! { $lanes => define_opaque_mask => new [$width; $lanes] }
105+
}
106+
};
107+
{ new [$width:ty; $lanes:tt] $($var:ident)* } => {
108+
/// Construct a vector by setting each lane to the given values.
109+
#[allow(clippy::too_many_arguments)]
110+
#[inline]
111+
pub const fn new($($var: bool),*) -> Self {
112+
Self(<[$width; $lanes] as MaskImpl>::Mask::new_from_bool($($var),*))
113+
}
114+
}
115+
}
116+
117+
define_opaque_mask! {
118+
/// Mask for 8 8-bit lanes
119+
struct mask8x8([u8; 8]);
120+
}
121+
122+
define_opaque_mask! {
123+
/// Mask for 16 8-bit lanes
124+
struct mask8x16([u8; 16]);
125+
}
126+
127+
define_opaque_mask! {
128+
/// Mask for 32 8-bit lanes
129+
struct mask8x32([u8; 32]);
130+
}
131+
132+
define_opaque_mask! {
133+
/// Mask for 64 8-bit lanes
134+
struct mask8x64([u8; 64]);
135+
}
136+
137+
define_opaque_mask! {
138+
/// Mask for 4 16-bit lanes
139+
struct mask16x4([u16; 4]);
140+
}
141+
142+
define_opaque_mask! {
143+
/// Mask for 8 16-bit lanes
144+
struct mask16x8([u16; 8]);
145+
}
146+
147+
define_opaque_mask! {
148+
/// Mask for 16 16-bit lanes
149+
struct mask16x16([u16; 16]);
150+
}
151+
152+
define_opaque_mask! {
153+
/// Mask for 32 16-bit lanes
154+
struct mask16x32([u16; 32]);
155+
}
156+
157+
define_opaque_mask! {
158+
/// Mask for 2 32-bit lanes
159+
struct mask32x2([u32; 2]);
160+
}
161+
162+
define_opaque_mask! {
163+
/// Mask for 4 32-bit lanes
164+
struct mask32x4([u32; 4]);
165+
}
166+
167+
define_opaque_mask! {
168+
/// Mask for 8 32-bit lanes
169+
struct mask32x8([u32; 8]);
170+
}
171+
172+
define_opaque_mask! {
173+
/// Mask for 16 32-bit lanes
174+
struct mask32x16([u32; 16]);
175+
}
176+
177+
define_opaque_mask! {
178+
/// Mask for 2 64-bit lanes
179+
struct mask64x2([u64; 2]);
180+
}
181+
182+
define_opaque_mask! {
183+
/// Mask for 4 64-bit lanes
184+
struct mask64x4([u64; 4]);
185+
}
186+
187+
define_opaque_mask! {
188+
/// Mask for 8 64-bit lanes
189+
struct mask64x8([u64; 8]);
190+
}
191+
192+
define_opaque_mask! {
193+
/// Mask for 2 128-bit lanes
194+
struct mask128x2([u128; 2]);
195+
}
196+
197+
define_opaque_mask! {
198+
/// Mask for 4 128-bit lanes
199+
struct mask128x4([u128; 4]);
200+
}
201+
202+
define_opaque_mask! {
203+
/// Mask for 2 `isize`-wide lanes
204+
struct masksizex2([usize; 2]);
205+
}
206+
207+
define_opaque_mask! {
208+
/// Mask for 4 `isize`-wide lanes
209+
struct masksizex4([usize; 4]);
210+
}
211+
212+
define_opaque_mask! {
213+
/// Mask for 8 `isize`-wide lanes
214+
struct masksizex8([usize; 8]);
215+
}

crates/core_simd/src/masks.rs renamed to crates/core_simd/src/masks/wide/mod.rs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
//! Masks that take up full vector registers.
2+
3+
mod vectors_m8;
4+
pub use vectors_m8::*;
5+
mod vectors_m16;
6+
pub use vectors_m16::*;
7+
mod vectors_m32;
8+
pub use vectors_m32::*;
9+
mod vectors_m64;
10+
pub use vectors_m64::*;
11+
mod vectors_m128;
12+
pub use vectors_m128::*;
13+
mod vectors_msize;
14+
pub use vectors_msize::*;
15+
116
/// The error type returned when converting an integer to a mask fails.
217
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
318
pub struct TryFromMaskError(());
@@ -95,30 +110,30 @@ macro_rules! define_mask {
95110

96111
define_mask! {
97112
/// 8-bit mask
98-
struct mask8(i8);
113+
struct m8(i8);
99114
}
100115

101116
define_mask! {
102117
/// 16-bit mask
103-
struct mask16(i16);
118+
struct m16(i16);
104119
}
105120

106121
define_mask! {
107122
/// 32-bit mask
108-
struct mask32(i32);
123+
struct m32(i32);
109124
}
110125

111126
define_mask! {
112127
/// 64-bit mask
113-
struct mask64(i64);
128+
struct m64(i64);
114129
}
115130

116131
define_mask! {
117132
/// 128-bit mask
118-
struct mask128(i128);
133+
struct m128(i128);
119134
}
120135

121136
define_mask! {
122137
/// `isize`-wide mask
123-
struct masksize(isize);
138+
struct msize(isize);
124139
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
use super::m128;
2+
3+
define_mask_vector! {
4+
/// Vector of two `m128` values
5+
struct m128x2([i128 as m128; 2]);
6+
}
7+
8+
define_mask_vector! {
9+
/// Vector of four `m128` values
10+
struct m128x4([i128 as m128; 4]);
11+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use super::m16;
2+
3+
define_mask_vector! {
4+
/// Vector of four `m16` values
5+
struct m16x4([i16 as m16; 4]);
6+
}
7+
8+
define_mask_vector! {
9+
/// Vector of eight `m16` values
10+
struct m16x8([i16 as m16; 8]);
11+
}
12+
13+
define_mask_vector! {
14+
/// Vector of 16 `m16` values
15+
struct m16x16([i16 as m16; 16]);
16+
}
17+
18+
define_mask_vector! {
19+
/// Vector of 32 `m16` values
20+
struct m16x32([i16 as m16; 32]);
21+
}

0 commit comments

Comments
 (0)