Skip to content

Commit af6956c

Browse files
committed
Make implementation functions safe
1 parent 744169a commit af6956c

File tree

3 files changed

+100
-40
lines changed

3 files changed

+100
-40
lines changed

crates/core_simd/src/masks/bitmask.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(unused_imports)]
22
use super::MaskElement;
33
use crate::simd::intrinsics;
4-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
4+
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask, ToBitMaskArray};
55
use core::marker::PhantomData;
66

77
/// A mask where each lane is represented by a single bit.
@@ -115,31 +115,40 @@ where
115115
unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) }
116116
}
117117

118-
// Safety: N must be the exact number of bytes required to hold the bitmask for this mask
119118
#[inline]
120119
#[must_use = "method returns a new array and does not mutate the original value"]
121-
pub unsafe fn to_bitmask_array<const N: usize>(self) -> [u8; N] {
122-
// Safety: these are the same type and we are laundering the generic
120+
pub fn to_bitmask_array<const N: usize>(self) -> [u8; N] {
121+
assert!(core::mem::size_of::<Self>() == N);
122+
123+
// Safety: converting an integer to an array of bytes of the same size is safe
123124
unsafe { core::mem::transmute_copy(&self.0) }
124125
}
125126

126127
// Safety: N must be the exact number of bytes required to hold the bitmask for this mask
127128
#[inline]
128129
#[must_use = "method returns a new mask and does not mutate the original value"]
129-
pub unsafe fn from_bitmask_array<const N: usize>(bitmask: [u8; N]) -> Self {
130-
// Safety: these are the same type and we are laundering the generic
130+
pub fn from_bitmask_array<const N: usize>(bitmask: [u8; N]) -> Self {
131+
assert!(core::mem::size_of::<Self>() == N);
132+
133+
// Safety: converting an array of bytes to an integer of the same size is safe
131134
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
132135
}
133136

134-
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
135137
#[inline]
136-
pub unsafe fn to_bitmask_integer<U>(self) -> U {
138+
pub fn to_bitmask_integer<U>(self) -> U
139+
where
140+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
141+
{
142+
// Safety: these are the same types
137143
unsafe { core::mem::transmute_copy(&self.0) }
138144
}
139145

140-
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
141146
#[inline]
142-
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
147+
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
148+
where
149+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
150+
{
151+
// Safety: these are the same types
143152
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
144153
}
145154

crates/core_simd/src/masks/full_masks.rs

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use super::MaskElement;
44
use crate::simd::intrinsics;
5-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
5+
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask, ToBitMaskArray};
66

77
#[repr(transparent)]
88
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
@@ -126,12 +126,26 @@ where
126126
unsafe { Mask(intrinsics::simd_cast(self.0)) }
127127
}
128128

129-
// Safety: N must be the exact number of bytes required to hold the bitmask for this mask
130129
#[inline]
131130
#[must_use = "method returns a new array and does not mutate the original value"]
132-
pub unsafe fn to_bitmask_array<const N: usize>(self) -> [u8; N] {
131+
pub fn to_bitmask_array<const N: usize>(self) -> [u8; N]
132+
where
133+
super::Mask<T, LANES>: ToBitMaskArray,
134+
[(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
135+
{
136+
assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);
137+
138+
// Safety: N is the correct bitmask size
139+
//
140+
// The transmute below allows this function to be marked safe, since it will prevent
141+
// monomorphization errors in the case of an incorrect size.
133142
unsafe {
134-
let mut bitmask: [u8; N] = intrinsics::simd_bitmask(self.0);
143+
// Compute the bitmask
144+
let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
145+
intrinsics::simd_bitmask(self.0);
146+
147+
// Transmute to the return type, previously asserted to be the same size
148+
let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask);
135149

136150
// There is a bug where LLVM appears to implement this operation with the wrong
137151
// bit order.
@@ -146,10 +160,19 @@ where
146160
}
147161
}
148162

149-
// Safety: N must be the exact number of bytes required to hold the bitmask for this mask
150163
#[inline]
151164
#[must_use = "method returns a new mask and does not mutate the original value"]
152-
pub unsafe fn from_bitmask_array<const N: usize>(mut bitmask: [u8; N]) -> Self {
165+
pub fn from_bitmask_array<const N: usize>(mut bitmask: [u8; N]) -> Self
166+
where
167+
super::Mask<T, LANES>: ToBitMaskArray,
168+
[(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
169+
{
170+
assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);
171+
172+
// Safety: N is the correct bitmask size
173+
//
174+
// The transmute below allows this function to be marked safe, since it will prevent
175+
// monomorphization errors in the case of an incorrect size.
153176
unsafe {
154177
// There is a bug where LLVM appears to implement this operation with the wrong
155178
// bit order.
@@ -160,6 +183,11 @@ where
160183
}
161184
}
162185

186+
// Transmute to the bitmask type, previously asserted to be the same size
187+
let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
188+
core::mem::transmute_copy(&bitmask);
189+
190+
// Compute the regular mask
163191
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
164192
bitmask,
165193
Self::splat(true).to_int(),
@@ -168,11 +196,12 @@ where
168196
}
169197
}
170198

171-
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
172-
// this mask
173199
#[inline]
174-
pub(crate) unsafe fn to_bitmask_integer<U: ReverseBits>(self) -> U {
175-
// Safety: caller must only return bitmask types
200+
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
201+
where
202+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
203+
{
204+
// Safety: U is required to be the appropriate bitmask type
176205
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
177206

178207
// There is a bug where LLVM appears to implement this operation with the wrong
@@ -188,7 +217,10 @@ where
188217
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
189218
// this mask
190219
#[inline]
191-
pub(crate) unsafe fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self {
220+
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
221+
where
222+
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
223+
{
192224
// There is a bug where LLVM appears to implement this operation with the wrong
193225
// bit order.
194226
// TODO fix this in a better way
@@ -198,7 +230,7 @@ where
198230
bitmask
199231
};
200232

201-
// Safety: caller must only pass bitmask types
233+
// Safety: U is required to be the appropriate bitmask type
202234
unsafe {
203235
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
204236
bitmask,

crates/core_simd/src/masks/to_bitmask.rs

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
11
use super::{mask_impl, Mask, MaskElement};
2+
use crate::simd::{LaneCount, SupportedLaneCount};
3+
4+
mod sealed {
5+
pub trait Sealed {}
6+
}
7+
pub use sealed::Sealed;
8+
9+
impl<T, const LANES: usize> Sealed for Mask<T, LANES>
10+
where
11+
T: MaskElement,
12+
LaneCount<LANES>: SupportedLaneCount,
13+
{
14+
}
215

316
/// Converts masks to and from integer bitmasks.
417
///
518
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB.
6-
pub trait ToBitMask {
19+
///
20+
/// # Safety
21+
/// This trait is `unsafe` and sealed, since the `BitMask` type must match the number of lanes in
22+
/// the mask.
23+
pub unsafe trait ToBitMask: Sealed {
724
/// The integer bitmask type.
825
type BitMask;
926

@@ -17,7 +34,11 @@ pub trait ToBitMask {
1734
/// Converts masks to and from byte array bitmasks.
1835
///
1936
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte.
20-
pub trait ToBitMaskArray {
37+
///
38+
/// # Safety
39+
/// This trait is `unsafe` and sealed, since the `BYTES` value must match the number of lanes in
40+
/// the mask.
41+
pub unsafe trait ToBitMaskArray: Sealed {
2142
/// The length of the bitmask array.
2243
const BYTES: usize;
2344

@@ -31,15 +52,15 @@ pub trait ToBitMaskArray {
3152
macro_rules! impl_integer_intrinsic {
3253
{ $(unsafe impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
3354
$(
34-
impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
55+
unsafe impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
3556
type BitMask = $int;
3657

3758
fn to_bitmask(self) -> $int {
38-
unsafe { self.0.to_bitmask_integer() }
59+
self.0.to_bitmask_integer()
3960
}
4061

4162
fn from_bitmask(bitmask: $int) -> Self {
42-
unsafe { Self(mask_impl::Mask::from_bitmask_integer(bitmask)) }
63+
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
4364
}
4465
}
4566
)*
@@ -59,20 +80,18 @@ pub const fn bitmask_len(lanes: usize) -> usize {
5980
}
6081

6182
macro_rules! impl_array_bitmask {
62-
{ $(impl ToBitMask<[u8; _]> for Mask<_, $lanes:literal>)* } => {
83+
{ $(impl ToBitMaskArray<[u8; _]> for Mask<_, $lanes:literal>)* } => {
6384
$(
64-
impl<T: MaskElement> ToBitMaskArray for Mask<T, $lanes>
85+
unsafe impl<T: MaskElement> ToBitMaskArray for Mask<T, $lanes>
6586
{
6687
const BYTES: usize = bitmask_len($lanes);
6788

6889
fn to_bitmask_array(self) -> [u8; Self::BYTES] {
69-
// Safety: BYTES is the exact size required
70-
unsafe { self.0.to_bitmask_array() }
90+
self.0.to_bitmask_array()
7191
}
7292

7393
fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self {
74-
// Safety: BYTES is the exact size required
75-
unsafe { Mask(mask_impl::Mask::from_bitmask_array(bitmask)) }
94+
Mask(mask_impl::Mask::from_bitmask_array(bitmask))
7695
}
7796
}
7897
)*
@@ -81,11 +100,11 @@ macro_rules! impl_array_bitmask {
81100

82101
// FIXME this should be specified generically, but it doesn't seem to work with rustc, yet
83102
impl_array_bitmask! {
84-
impl ToBitMask<[u8; _]> for Mask<_, 1>
85-
impl ToBitMask<[u8; _]> for Mask<_, 2>
86-
impl ToBitMask<[u8; _]> for Mask<_, 4>
87-
impl ToBitMask<[u8; _]> for Mask<_, 8>
88-
impl ToBitMask<[u8; _]> for Mask<_, 16>
89-
impl ToBitMask<[u8; _]> for Mask<_, 32>
90-
impl ToBitMask<[u8; _]> for Mask<_, 64>
103+
impl ToBitMaskArray<[u8; _]> for Mask<_, 1>
104+
impl ToBitMaskArray<[u8; _]> for Mask<_, 2>
105+
impl ToBitMaskArray<[u8; _]> for Mask<_, 4>
106+
impl ToBitMaskArray<[u8; _]> for Mask<_, 8>
107+
impl ToBitMaskArray<[u8; _]> for Mask<_, 16>
108+
impl ToBitMaskArray<[u8; _]> for Mask<_, 32>
109+
impl ToBitMaskArray<[u8; _]> for Mask<_, 64>
91110
}

0 commit comments

Comments
 (0)