Skip to content

Commit 20e8c4c

Browse files
committed
Add simd_load platform-intrinsic
This maps to a masked vector load - llvm.masked.load
1 parent 0b24479 commit 20e8c4c

File tree

4 files changed

+132
-0
lines changed

4 files changed

+132
-0
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,6 +1492,100 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14921492
return Ok(v);
14931493
}
14941494

1495+
if name == sym::simd_load {
1496+
// simd_load(values: <N x T>, pointer: *_ T, mask: <N x i{M}>) -> <N x T>
1497+
// * N: number of elements in the input vectors
1498+
// * T: type of the element to load
1499+
// * M: any integer width is supported, will be truncated to i1
1500+
1501+
// The first argument (`values`) must be a SIMD vector; it is passed as the last operand of llvm.masked.load:
1502+
let (_, element_ty0) = require_simd!(in_ty, SimdFirst);
1503+
1504+
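// The third argument (`mask`) must be a SIMD vector; its element type is checked below to be a signed integer of any width. NOTE(review): the error variant used here is `SimdSecond` although this is `arg_tys[2]` -- confirm: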
1505+
let (mask_len, element_ty2) = require_simd!(arg_tys[2], SimdSecond);
1506+
require_simd!(ret_ty, SimdReturn);
1507+
1508+
// The values vector and the mask vector must have the same length:
1509+
require!(
1510+
in_len == mask_len,
1511+
InvalidMonomorphization::ThirdArgumentLength {
1512+
span,
1513+
name,
1514+
in_len,
1515+
in_ty,
1516+
arg_ty: arg_tys[2],
1517+
out_len: mask_len
1518+
}
1519+
);
1520+
1521+
// The return type must match the first argument type
1522+
require!(
1523+
ret_ty == in_ty,
1524+
InvalidMonomorphization::ExpectedReturnType { span, name, in_ty, ret_ty }
1525+
);
1526+
1527+
// Pointer type must match the element type
1528+
require!(
1529+
matches!(
1530+
arg_tys[1].kind(),
1531+
ty::RawPtr(p) if p.ty == in_elem && p.ty.kind() == element_ty0.kind()
1532+
),
1533+
InvalidMonomorphization::ExpectedElementType {
1534+
span,
1535+
name,
1536+
expected_element: in_elem,
1537+
second_arg: arg_tys[1],
1538+
in_elem,
1539+
in_ty,
1540+
mutability: ExpectedPointerMutability::Not,
1541+
}
1542+
);
1543+
1544+
// Mask elements must be of a signed integer type
1545+
match element_ty2.kind() {
1546+
ty::Int(_) => (),
1547+
_ => {
1548+
return_error!(InvalidMonomorphization::ThirdArgElementType {
1549+
span,
1550+
name,
1551+
expected_element: element_ty2,
1552+
third_arg: arg_tys[2]
1553+
});
1554+
}
1555+
}
1556+
1557+
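// Alignment operand for llvm.masked.load, must be a constant integer value. NOTE(review): this is computed from the vector type `in_ty`, but `pointer` is only required to be `*_ T` -- confirm whether the alignment of `in_elem` was intended: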
1558+
let alignment_ty = bx.type_i32();
1559+
let alignment = bx.const_i32(bx.align_of(in_ty).bytes() as i32);
1560+
1561+
// Truncate the mask vector to a vector of i1s:
1562+
let (mask, mask_ty) = {
1563+
let i1 = bx.type_i1();
1564+
let i1xn = bx.type_vector(i1, mask_len);
1565+
(bx.trunc(args[2].immediate(), i1xn), i1xn)
1566+
};
1567+
1568+
let llvm_pointer = bx.type_ptr();
1569+
1570+
// Type of the vector of elements:
1571+
let llvm_elem_vec_ty = llvm_vector_ty(bx, in_elem, mask_len);
1572+
let llvm_elem_vec_str = llvm_vector_str(bx, in_elem, mask_len);
1573+
1574+
let llvm_intrinsic = format!("llvm.masked.load.{llvm_elem_vec_str}.p0");
1575+
let fn_ty = bx
1576+
.type_func(&[llvm_pointer, alignment_ty, mask_ty, llvm_elem_vec_ty], llvm_elem_vec_ty);
1577+
let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
1578+
let v = bx.call(
1579+
fn_ty,
1580+
None,
1581+
None,
1582+
f,
1583+
&[args[1].immediate(), alignment, mask, args[0].immediate()],
1584+
None,
1585+
);
1586+
return Ok(v);
1587+
}
1588+
14951589
if name == sym::simd_scatter {
14961590
// simd_scatter(values: <N x T>, pointers: <N x *mut T>,
14971591
// mask: <N x i{M}>) -> ()

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
540540
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
541541
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
542542
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
543+
sym::simd_load => (3, 0, vec![param(0), param(1), param(2)], param(0)),
543544
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
544545
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
545546
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,7 @@ symbols! {
15311531
simd_gt,
15321532
simd_insert,
15331533
simd_le,
1534+
simd_load,
15341535
simd_lt,
15351536
simd_mul,
15361537
simd_ne,
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//
2+
3+
// compile-flags: -C no-prepopulate-passes
4+
5+
#![crate_type = "lib"]
6+
7+
#![feature(repr_simd, platform_intrinsics)]
8+
#![allow(non_camel_case_types)]
9+
10+
#[repr(simd)]
11+
#[derive(Copy, Clone, PartialEq, Debug)]
12+
pub struct Vec2<T>(pub T, pub T);
13+
14+
#[repr(simd)]
15+
#[derive(Copy, Clone, PartialEq, Debug)]
16+
pub struct Vec4<T>(pub T, pub T, pub T, pub T);
17+
18+
extern "platform-intrinsic" {
19+
fn simd_load<T, P, M>(values: T, pointer: P, mask: M) -> T;
20+
}
21+
22+
// CHECK-LABEL: @load_f32x2
23+
#[no_mangle]
24+
pub unsafe fn load_f32x2(pointer: *const f32, mask: Vec2<i32>,
25+
values: Vec2<f32>) -> Vec2<f32> {
26+
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x float> {{.*}})
27+
simd_load(values, pointer, mask)
28+
}
29+
30+
// CHECK-LABEL: @load_pf32x2
31+
#[no_mangle]
32+
pub unsafe fn load_pf32x2(pointer: *const *const f32, mask: Vec2<i32>,
33+
values: Vec2<*const f32>) -> Vec2<*const f32> {
34+
// CHECK: call <2 x ptr> @llvm.masked.load.v2p0.p0({{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x ptr> {{.*}})
35+
simd_load(values, pointer, mask)
36+
}

0 commit comments

Comments
 (0)