Skip to content

Commit c761d6f

Browse files
author
Daniel Smith
committed
Add remaining gather intrinsics
1 parent 1c81797 commit c761d6f

File tree

4 files changed

+316
-1
lines changed

4 files changed

+316
-1
lines changed

crates/core_arch/src/simd.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ simd_ty!(u32x16[u32]:
204204
| x0, x1, x2, x3, x4, x5, x6, x7,
205205
x8, x9, x10, x11, x12, x13, x14, x15);
206206

207+
simd_ty!(f32x16[f32]:
208+
f32, f32, f32, f32, f32, f32, f32, f32,
209+
f32, f32, f32, f32, f32, f32, f32, f32
210+
| x0, x1, x2, x3, x4, x5, x6, x7,
211+
x8, x9, x10, x11, x12, x13, x14, x15);
212+
207213
simd_ty!(i64x8[i64]:
208214
i64, i64, i64, i64, i64, i64, i64, i64
209215
| x0, x1, x2, x3, x4, x5, x6, x7);

crates/core_arch/src/x86/avx512f.rs

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ pub unsafe fn _mm512_setzero_pd() -> __m512d {
5959
mem::zeroed()
6060
}
6161

62+
/// Returns vector of type `__m512d` with all elements set to zero.
63+
///
64+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#avx512techs=AVX512F&expand=33,34,4990&text=_mm512_setzero_pd)
65+
#[inline]
66+
#[target_feature(enable = "avx512f")]
67+
#[cfg_attr(test, assert_instr(vxorps))]
68+
pub unsafe fn _mm512_setzero_ps() -> __m512 {
69+
// All-0 is a properly initialized __m512
70+
mem::zeroed()
71+
}
72+
6273
/// Returns vector of type `__m512i` with all elements set to zero.
6374
///
6475
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#avx512techs=AVX512F&expand=33,34,4990&text=_mm512_setzero_si512)
@@ -239,6 +250,101 @@ pub unsafe fn _mm512_mask_i64gather_ps(
239250
transmute(r)
240251
}
241252

253+
/// Gather single-precision (32-bit) floating-point elements from memory using 32-bit indices.
254+
///
255+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_i32gather_ps)
256+
#[inline]
257+
#[target_feature(enable = "avx512f")]
258+
#[cfg_attr(test, assert_instr(vgatherdps, scale = 1))]
259+
#[rustc_args_required_const(2)]
260+
pub unsafe fn _mm512_i32gather_ps(offsets: __m512i, slice: *const u8, scale: i32) -> __m512 {
261+
let zero = _mm512_setzero_ps().as_f32x16();
262+
let neg_one = -1;
263+
let slice = slice as *const i8;
264+
let offsets = offsets.as_i32x16();
265+
macro_rules! call {
266+
($imm8:expr) => {
267+
vgatherdps(zero, slice, offsets, neg_one, $imm8)
268+
};
269+
}
270+
let r = constify_imm8_gather!(scale, call);
271+
transmute(r)
272+
}
273+
274+
/// Gather single-precision (32-bit) floating-point elements from memory using 32-bit indices.
275+
///
276+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_mask_i32gather_ps)
277+
#[inline]
278+
#[target_feature(enable = "avx512f")]
279+
#[cfg_attr(test, assert_instr(vgatherdps, scale = 1))]
280+
#[rustc_args_required_const(4)]
281+
pub unsafe fn _mm512_mask_i32gather_ps(
282+
src: __m512,
283+
mask: __mmask16,
284+
offsets: __m512i,
285+
slice: *const u8,
286+
scale: i32,
287+
) -> __m512 {
288+
let src = src.as_f32x16();
289+
let slice = slice as *const i8;
290+
let offsets = offsets.as_i32x16();
291+
macro_rules! call {
292+
($imm8:expr) => {
293+
vgatherdps(src, slice, offsets, mask as i16, $imm8)
294+
};
295+
}
296+
let r = constify_imm8_gather!(scale, call);
297+
transmute(r)
298+
}
299+
300+
/// Gather 32-bit integers from memory using 32-bit indices.
301+
///
302+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_i32gather_epi32)
303+
#[inline]
304+
#[target_feature(enable = "avx512f")]
305+
#[cfg_attr(test, assert_instr(vpgatherdd, scale = 1))]
306+
#[rustc_args_required_const(2)]
307+
pub unsafe fn _mm512_i32gather_epi32(offsets: __m512i, slice: *const u8, scale: i32) -> __m512i {
308+
let zero = _mm512_setzero_si512().as_i32x16();
309+
let neg_one = -1;
310+
let slice = slice as *const i8;
311+
let offsets = offsets.as_i32x16();
312+
macro_rules! call {
313+
($imm8:expr) => {
314+
vpgatherdd(zero, slice, offsets, neg_one, $imm8)
315+
};
316+
}
317+
let r = constify_imm8_gather!(scale, call);
318+
transmute(r)
319+
}
320+
321+
/// Gather 32-bit integers from memory using 32-bit indices.
322+
///
323+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_mask_i32gather_epi32)
324+
#[inline]
325+
#[target_feature(enable = "avx512f")]
326+
#[cfg_attr(test, assert_instr(vpgatherdd, scale = 1))]
327+
#[rustc_args_required_const(4)]
328+
pub unsafe fn _mm512_mask_i32gather_epi32(
329+
src: __m512i,
330+
mask: __mmask16,
331+
offsets: __m512i,
332+
slice: *const u8,
333+
scale: i32,
334+
) -> __m512i {
335+
let src = src.as_i32x16();
336+
let mask = mask as i16;
337+
let slice = slice as *const i8;
338+
let offsets = offsets.as_i32x16();
339+
macro_rules! call {
340+
($imm8:expr) => {
341+
vpgatherdd(src, slice, offsets, mask, $imm8)
342+
};
343+
}
344+
let r = constify_imm8!(scale, call);
345+
transmute(r)
346+
}
347+
242348
/// Gather 64-bit integers from memory using 32-bit indices.
243349
///
244350
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_i32gather_epi64)
@@ -383,13 +489,78 @@ pub unsafe fn _mm512_mask_i64gather_epi32(
383489
transmute(r)
384490
}
385491

492+
/// Sets packed 32-bit integers in `dst` with the supplied values.
493+
///
494+
/// [Intel's documentation]( https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=727,1063,4909,1062,1062,4909&text=_mm512_set_ps)
495+
#[inline]
496+
#[target_feature(enable = "avx512f")]
497+
pub unsafe fn _mm512_set_ps(
498+
e0: f32,
499+
e1: f32,
500+
e2: f32,
501+
e3: f32,
502+
e4: f32,
503+
e5: f32,
504+
e6: f32,
505+
e7: f32,
506+
e8: f32,
507+
e9: f32,
508+
e10: f32,
509+
e11: f32,
510+
e12: f32,
511+
e13: f32,
512+
e14: f32,
513+
e15: f32,
514+
) -> __m512 {
515+
_mm512_setr_ps(
516+
e15, e14, e13, e12, e11, e10, e9, e8, e7, e6, e5, e4, e3, e2, e1, e0,
517+
)
518+
}
519+
520+
/// Sets packed 32-bit integers in `dst` with the supplied values in
521+
/// reverse order.
522+
///
523+
/// [Intel's documentation]( https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=727,1063,4909,1062,1062,4909&text=_mm512_set_ps)
524+
#[inline]
525+
#[target_feature(enable = "avx512f")]
526+
pub unsafe fn _mm512_setr_ps(
527+
e0: f32,
528+
e1: f32,
529+
e2: f32,
530+
e3: f32,
531+
e4: f32,
532+
e5: f32,
533+
e6: f32,
534+
e7: f32,
535+
e8: f32,
536+
e9: f32,
537+
e10: f32,
538+
e11: f32,
539+
e12: f32,
540+
e13: f32,
541+
e14: f32,
542+
e15: f32,
543+
) -> __m512 {
544+
let r = f32x16::new(
545+
e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, e12, e13, e14, e15,
546+
);
547+
transmute(r)
548+
}
549+
386550
/// Broadcast 64-bit float `a` to all elements of `dst`.
387551
#[inline]
388552
#[target_feature(enable = "avx512f")]
389553
pub unsafe fn _mm512_set1_pd(a: f64) -> __m512d {
390554
transmute(f64x8::splat(a))
391555
}
392556

557+
/// Broadcast 32-bit float `a` to all elements of `dst`.
558+
#[inline]
559+
#[target_feature(enable = "avx512f")]
560+
pub unsafe fn _mm512_set1_ps(a: f32) -> __m512 {
561+
transmute(f32x16::splat(a))
562+
}
563+
393564
/// Sets packed 32-bit integers in `dst` with the supplied values.
394565
#[inline]
395566
#[target_feature(enable = "avx512f")]
@@ -1119,12 +1290,16 @@ pub const _MM_CMPINT_TRUE: _MM_CMPINT_ENUM = 0x07;
11191290
extern "C" {
11201291
#[link_name = "llvm.x86.avx512.gather.dpd.512"]
11211292
fn vgatherdpd(src: f64x8, slice: *const i8, offsets: i32x8, mask: i8, scale: i32) -> f64x8;
1293+
#[link_name = "llvm.x86.avx512.gather.dps.512"]
1294+
fn vgatherdps(src: f32x16, slice: *const i8, offsets: i32x16, mask: i16, scale: i32) -> f32x16;
11221295
#[link_name = "llvm.x86.avx512.gather.qpd.512"]
11231296
fn vgatherqpd(src: f64x8, slice: *const i8, offsets: i64x8, mask: i8, scale: i32) -> f64x8;
11241297
#[link_name = "llvm.x86.avx512.gather.qps.512"]
11251298
fn vgatherqps(src: f32x8, slice: *const i8, offsets: i64x8, mask: i8, scale: i32) -> f32x8;
11261299
#[link_name = "llvm.x86.avx512.gather.dpq.512"]
11271300
fn vpgatherdq(src: i64x8, slice: *const i8, offsets: i32x8, mask: i8, scale: i32) -> i64x8;
1301+
#[link_name = "llvm.x86.avx512.gather.dpi.512"]
1302+
fn vpgatherdd(src: i32x16, slice: *const i8, offsets: i32x16, mask: i16, scale: i32) -> i32x16;
11281303
#[link_name = "llvm.x86.avx512.gather.qpq.512"]
11291304
fn vpgatherqq(src: i64x8, slice: *const i8, offsets: i64x8, mask: i8, scale: i32) -> i64x8;
11301305
#[link_name = "llvm.x86.avx512.gather.qpi.512"]
@@ -1244,6 +1419,74 @@ mod tests {
12441419
assert_eq_m512i(r, e);
12451420
}
12461421

1422+
#[simd_test(enable = "avx512f")]
1423+
unsafe fn test_mm512_i32gather_ps() {
1424+
let mut arr = [0f32; 256];
1425+
for i in 0..256 {
1426+
arr[i] = i as f32;
1427+
}
1428+
// A multiplier of 4 is word-addressing
1429+
#[rustfmt::skip]
1430+
let index = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112,
1431+
120, 128, 136, 144, 152, 160, 168, 176);
1432+
let r = _mm512_i32gather_ps(index, arr.as_ptr() as *const u8, 4);
1433+
#[rustfmt::skip]
1434+
assert_eq_m512(r, _mm512_setr_ps(0., 16., 32., 48., 64., 80., 96., 112.,
1435+
120., 128., 136., 144., 152., 160., 168., 176.));
1436+
}
1437+
1438+
#[simd_test(enable = "avx512f")]
1439+
unsafe fn test_mm512_mask_i32gather_ps() {
1440+
let mut arr = [0f32; 256];
1441+
for i in 0..256 {
1442+
arr[i] = i as f32;
1443+
}
1444+
let src = _mm512_set1_ps(2.);
1445+
let mask = 0b10101010_10101010;
1446+
#[rustfmt::skip]
1447+
let index = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112,
1448+
120, 128, 136, 144, 152, 160, 168, 176);
1449+
// A multiplier of 4 is word-addressing
1450+
let r = _mm512_mask_i32gather_ps(src, mask, index, arr.as_ptr() as *const u8, 4);
1451+
#[rustfmt::skip]
1452+
assert_eq_m512(r, _mm512_setr_ps(2., 16., 2., 48., 2., 80., 2., 112.,
1453+
2., 128., 2., 144., 2., 160., 2., 176.));
1454+
}
1455+
1456+
#[simd_test(enable = "avx512f")]
1457+
unsafe fn test_mm512_i32gather_epi32() {
1458+
let mut arr = [0i32; 256];
1459+
for i in 0..256 {
1460+
arr[i] = i as i32;
1461+
}
1462+
// A multiplier of 4 is word-addressing
1463+
#[rustfmt::skip]
1464+
let index = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112,
1465+
120, 128, 136, 144, 152, 160, 168, 176);
1466+
let r = _mm512_i32gather_epi32(index, arr.as_ptr() as *const u8, 4);
1467+
#[rustfmt::skip]
1468+
assert_eq_m512i(r, _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112,
1469+
120, 128, 136, 144, 152, 160, 168, 176));
1470+
}
1471+
1472+
#[simd_test(enable = "avx512f")]
1473+
unsafe fn test_mm512_mask_i32gather_epi32() {
1474+
let mut arr = [0i32; 256];
1475+
for i in 0..256 {
1476+
arr[i] = i as i32;
1477+
}
1478+
let src = _mm512_set1_epi32(2);
1479+
let mask = 0b10101010_10101010;
1480+
#[rustfmt::skip]
1481+
let index = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112,
1482+
120, 128, 136, 144, 152, 160, 168, 176);
1483+
// A multiplier of 4 is word-addressing
1484+
let r = _mm512_mask_i32gather_epi32(src, mask, index, arr.as_ptr() as *const u8, 4);
1485+
#[rustfmt::skip]
1486+
assert_eq_m512i(r, _mm512_setr_epi32(2, 16, 2, 48, 2, 80, 2, 112,
1487+
2, 128, 2, 144, 2, 160, 2, 176));
1488+
}
1489+
12471490
#[simd_test(enable = "avx512f")]
12481491
unsafe fn test_mm512_cmplt_epu32_mask() {
12491492
#[rustfmt::skip]
@@ -1586,4 +1829,43 @@ mod tests {
15861829
unsafe fn test_mm512_setzero_si512() {
15871830
assert_eq_m512i(_mm512_set1_epi32(0), _mm512_setzero_si512());
15881831
}
1832+
1833+
#[simd_test(enable = "avx512f")]
1834+
unsafe fn test_mm512_set_ps() {
1835+
let r = _mm512_setr_ps(
1836+
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,
1837+
);
1838+
assert_eq_m512(
1839+
r,
1840+
_mm512_set_ps(
1841+
15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 0.,
1842+
),
1843+
)
1844+
}
1845+
1846+
#[simd_test(enable = "avx512f")]
1847+
unsafe fn test_mm512_setr_ps() {
1848+
let r = _mm512_set_ps(
1849+
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,
1850+
);
1851+
assert_eq_m512(
1852+
r,
1853+
_mm512_setr_ps(
1854+
15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 0.,
1855+
),
1856+
)
1857+
}
1858+
1859+
#[simd_test(enable = "avx512f")]
1860+
unsafe fn test_mm512_set1_ps() {
1861+
#[rustfmt::skip]
1862+
let expected = _mm512_set_ps(2., 2., 2., 2., 2., 2., 2., 2.,
1863+
2., 2., 2., 2., 2., 2., 2., 2.);
1864+
assert_eq_m512(expected, _mm512_set1_ps(2.));
1865+
}
1866+
1867+
#[simd_test(enable = "avx512f")]
1868+
unsafe fn test_mm512_setzero_ps() {
1869+
assert_eq_m512(_mm512_setzero_ps(), _mm512_set1_ps(0.));
1870+
}
15891871
}

crates/core_arch/src/x86/mod.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,24 @@ impl m512iExt for __m512i {
559559
}
560560
}
561561

562+
#[allow(non_camel_case_types)]
563+
#[unstable(feature = "stdimd_internal", issue = "none")]
564+
pub(crate) trait m512Ext: Sized {
565+
fn as_m512(self) -> __m512;
566+
567+
#[inline]
568+
fn as_f32x16(self) -> crate::core_arch::simd::f32x16 {
569+
unsafe { transmute(self.as_m512()) }
570+
}
571+
}
572+
573+
impl m512Ext for __m512 {
574+
#[inline]
575+
fn as_m512(self) -> Self {
576+
self
577+
}
578+
}
579+
562580
#[allow(non_camel_case_types)]
563581
#[unstable(feature = "stdimd_internal", issue = "none")]
564582
pub(crate) trait m512dExt: Sized {

crates/core_arch/src/x86/test.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,17 @@ pub unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) {
144144
assert_eq!(A { a }.b, A { a: b }.b)
145145
}
146146

147+
pub unsafe fn assert_eq_m512(a: __m512, b: __m512) {
148+
// TODO: This should use `_mm512_cmpeq_ps_mask`, but that isn't yet implemented.
149+
union A {
150+
a: __m512,
151+
b: [f32; 16],
152+
}
153+
assert_eq!(A { a }.b, A { a: b }.b)
154+
}
155+
147156
pub unsafe fn assert_eq_m512d(a: __m512d, b: __m512d) {
148-
// TODO: This should probably use `_mm512_cmpeq_pd_mask`, but that requires KNC.
157+
// TODO: This should use `_mm512_cmpeq_pd_mask`, but that isn't yet implemented.
149158
union A {
150159
a: __m512d,
151160
b: [f64; 8],

0 commit comments

Comments
 (0)