Skip to content

Commit cd7d73f

Browse files
committed
Use bitmask trait
1 parent 4910274 commit cd7d73f

File tree

5 files changed

+157
-25
lines changed

5 files changed

+157
-25
lines changed

crates/core_simd/src/masks.rs

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212
)]
1313
mod mask_impl;
1414

15-
use crate::simd::intrinsics;
16-
use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
15+
mod to_bitmask;
16+
pub use to_bitmask::ToBitMask;
17+
18+
#[cfg(feature = "generic_const_exprs")]
19+
pub use to_bitmask::bitmask_len;
20+
21+
use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount};
1722
use core::cmp::Ordering;
1823
use core::{fmt, mem};
1924

@@ -216,22 +221,6 @@ where
216221
}
217222
}
218223

219-
/// Convert this mask to a bitmask, with one bit set per lane.
220-
#[cfg(feature = "generic_const_exprs")]
221-
#[inline]
222-
#[must_use = "method returns a new array and does not mutate the original value"]
223-
pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
224-
self.0.to_bitmask()
225-
}
226-
227-
/// Convert a bitmask to a mask.
228-
#[cfg(feature = "generic_const_exprs")]
229-
#[inline]
230-
#[must_use = "method returns a new mask and does not mutate the original value"]
231-
pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
232-
Self(mask_impl::Mask::from_bitmask(bitmask))
233-
}
234-
235224
/// Returns true if any lane is set, or false otherwise.
236225
#[inline]
237226
#[must_use = "method returns a new bool and does not mutate the original value"]

crates/core_simd/src/masks/bitmask.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,29 @@ where
118118
#[cfg(feature = "generic_const_exprs")]
119119
#[inline]
120120
#[must_use = "method returns a new array and does not mutate the original value"]
121-
pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
121+
pub fn to_bitmask(self) -> [u8; super::bitmask_len(LANES)] {
122122
// Safety: these are the same type and we are laundering the generic
123123
unsafe { core::mem::transmute_copy(&self.0) }
124124
}
125125

126126
#[cfg(feature = "generic_const_exprs")]
127127
#[inline]
128128
#[must_use = "method returns a new mask and does not mutate the original value"]
129-
pub fn from_bitmask(bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
129+
pub fn from_bitmask(bitmask: [u8; super::bitmask_len(LANES)]) -> Self {
130130
// Safety: these are the same type and we are laundering the generic
131131
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
132132
}
133133

134+
#[inline]
135+
pub unsafe fn to_bitmask_intrinsic<U>(self) -> U {
136+
unsafe { core::mem::transmute_copy(&self.0) }
137+
}
138+
139+
#[inline]
140+
pub unsafe fn from_bitmask_intrinsic<U>(bitmask: U) -> Self {
141+
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
142+
}
143+
134144
#[inline]
135145
#[must_use = "method returns a new mask and does not mutate the original value"]
136146
pub fn convert<U>(self) -> Mask<U, LANES>

crates/core_simd/src/masks/full_masks.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ where
112112
#[cfg(feature = "generic_const_exprs")]
113113
#[inline]
114114
#[must_use = "method returns a new array and does not mutate the original value"]
115-
pub fn to_bitmask(self) -> [u8; LaneCount::<LANES>::BITMASK_LEN] {
115+
pub fn to_bitmask(self) -> [u8; super::bitmask_len(LANES)] {
116116
unsafe {
117-
let mut bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN] =
117+
let mut bitmask: [u8; super::bitmask_len(LANES)] =
118118
intrinsics::simd_bitmask(self.0);
119119

120120
// There is a bug where LLVM appears to implement this operation with the wrong
@@ -133,7 +133,7 @@ where
133133
#[cfg(feature = "generic_const_exprs")]
134134
#[inline]
135135
#[must_use = "method returns a new mask and does not mutate the original value"]
136-
pub fn from_bitmask(mut bitmask: [u8; LaneCount::<LANES>::BITMASK_LEN]) -> Self {
136+
pub fn from_bitmask(mut bitmask: [u8; super::bitmask_len(LANES)]) -> Self {
137137
unsafe {
138138
// There is a bug where LLVM appears to implement this operation with the wrong
139139
// bit order.
@@ -152,6 +152,24 @@ where
152152
}
153153
}
154154

155+
#[inline]
156+
pub unsafe fn to_bitmask_intrinsic<U>(self) -> U {
157+
// Safety: caller must only return bitmask types
158+
unsafe { intrinsics::simd_bitmask(self.0) }
159+
}
160+
161+
#[inline]
162+
pub unsafe fn from_bitmask_intrinsic<U>(bitmask: U) -> Self {
163+
// Safety: caller must only pass bitmask types
164+
unsafe {
165+
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
166+
bitmask,
167+
Self::splat(true).to_int(),
168+
Self::splat(false).to_int(),
169+
))
170+
}
171+
}
172+
155173
#[inline]
156174
#[must_use = "method returns a new bool and does not mutate the original value"]
157175
pub fn any(self) -> bool {
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
use super::{mask_impl, Mask, MaskElement};
2+
use crate::{LaneCount, SupportedLaneCount};
3+
4+
/// Converts masks to and from bitmasks.
5+
///
6+
/// In a bitmask, each bit represents if the corresponding lane in the mask is set.
7+
pub trait ToBitMask<BitMask> {
8+
/// Converts a mask to a bitmask.
9+
fn to_bitmask(self) -> BitMask;
10+
11+
/// Converts a bitmask to a mask.
12+
fn from_bitmask(bitmask: BitMask) -> Self;
13+
}
14+
15+
macro_rules! impl_integer_intrinsic {
16+
{ $(unsafe impl ToBitMask<$int:ty> for Mask<_, $lanes:literal>)* } => {
17+
$(
18+
impl<T: MaskElement> ToBitMask<$int> for Mask<T, $lanes> {
19+
fn to_bitmask(self) -> $int {
20+
unsafe { self.0.to_bitmask_intrinsic() }
21+
}
22+
23+
fn from_bitmask(bitmask: $int) -> Self {
24+
unsafe { Self(mask_impl::Mask::from_bitmask_intrinsic(bitmask)) }
25+
}
26+
}
27+
)*
28+
}
29+
}
30+
31+
impl_integer_intrinsic! {
32+
unsafe impl ToBitMask<u8> for Mask<_, 8>
33+
unsafe impl ToBitMask<u16> for Mask<_, 16>
34+
unsafe impl ToBitMask<u32> for Mask<_, 32>
35+
unsafe impl ToBitMask<u64> for Mask<_, 64>
36+
}
37+
38+
macro_rules! impl_integer_via {
39+
{ $(impl ToBitMask<$int:ty, via $via:ty> for Mask<_, $lanes:literal>)* } => {
40+
$(
41+
impl<T: MaskElement> ToBitMask<$int> for Mask<T, $lanes> {
42+
fn to_bitmask(self) -> $int {
43+
let bitmask: $via = self.to_bitmask();
44+
bitmask as _
45+
}
46+
47+
fn from_bitmask(bitmask: $int) -> Self {
48+
Self::from_bitmask(bitmask as $via)
49+
}
50+
}
51+
)*
52+
}
53+
}
54+
55+
impl_integer_via! {
56+
impl ToBitMask<u16, via u8> for Mask<_, 8>
57+
impl ToBitMask<u32, via u8> for Mask<_, 8>
58+
impl ToBitMask<u64, via u8> for Mask<_, 8>
59+
60+
impl ToBitMask<u32, via u16> for Mask<_, 16>
61+
impl ToBitMask<u64, via u16> for Mask<_, 16>
62+
63+
impl ToBitMask<u64, via u32> for Mask<_, 32>
64+
}
65+
66+
#[cfg(target_pointer_width = "32")]
67+
impl_integer_via! {
68+
impl ToBitMask<usize, via u8> for Mask<_, 8>
69+
impl ToBitMask<usize, via u16> for Mask<_, 16>
70+
impl ToBitMask<usize, via u32> for Mask<_, 32>
71+
}
72+
73+
#[cfg(target_pointer_width = "64")]
74+
impl_integer_via! {
75+
impl ToBitMask<usize, via u8> for Mask<_, 8>
76+
impl ToBitMask<usize, via u16> for Mask<_, 16>
77+
impl ToBitMask<usize, via u32> for Mask<_, 32>
78+
impl ToBitMask<usize, via u64> for Mask<_, 64>
79+
}
80+
81+
/// Returns the minimum numnber of bytes in a bitmask with `lanes` lanes.
82+
#[cfg(feature = "generic_const_exprs")]
83+
pub const fn bitmask_len(lanes: usize) -> usize {
84+
(lanes + 7) / 8
85+
}
86+
87+
#[cfg(feature = "generic_const_exprs")]
88+
impl<T: MaskElement, const LANES: usize> ToBitMask<[u8; bitmask_len(LANES)]> for Mask<T, LANES>
89+
where
90+
LaneCount<LANES>: SupportedLaneCount,
91+
{
92+
fn to_bitmask(self) -> [u8; bitmask_len(LANES)] {
93+
self.0.to_bitmask()
94+
}
95+
96+
fn from_bitmask(bitmask: [u8; bitmask_len(LANES)]) -> Self {
97+
Mask(mask_impl::Mask::from_bitmask(bitmask))
98+
}
99+
}

crates/core_simd/tests/masks.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,34 @@ macro_rules! test_mask_api {
6868
assert_eq!(core_simd::Mask::<$type, 8>::from_int(int), mask);
6969
}
7070

71+
#[test]
72+
fn roundtrip_bitmask_array_conversion() {
73+
use core_simd::ToBitMask;
74+
let values = [
75+
true, false, false, true, false, false, true, false,
76+
true, true, false, false, false, false, false, true,
77+
];
78+
let mask = core_simd::Mask::<$type, 16>::from_array(values);
79+
let bitmask: u16 = mask.to_bitmask();
80+
assert_eq!(bitmask, 0b1000001101001001);
81+
assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask);
82+
}
83+
84+
/*
7185
#[cfg(feature = "generic_const_exprs")]
7286
#[test]
73-
fn roundtrip_bitmask_conversion() {
87+
fn roundtrip_bitmask_array_conversion() {
88+
use core_simd::ToBitMask;
7489
let values = [
7590
true, false, false, true, false, false, true, false,
7691
true, true, false, false, false, false, false, true,
7792
];
7893
let mask = core_simd::Mask::<$type, 16>::from_array(values);
79-
let bitmask = mask.to_bitmask();
94+
let bitmask: [u8; 2] = mask.to_bitmask();
8095
assert_eq!(bitmask, [0b01001001, 0b10000011]);
8196
assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask);
8297
}
98+
*/
8399
}
84100
}
85101
}

0 commit comments

Comments
 (0)