Skip to content

Commit e3305ba

Browse files
committed
Add simd_masked_store platform-intrinsic
This maps to llvm.masked.store
1 parent 6bf4273 commit e3305ba

File tree

3 files changed

+104
-1
lines changed

3 files changed

+104
-1
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14981498
// * T: type of the element to load
14991499
// * M: any integer width is supported, will be truncated to i1
15001500

1501-
// The first argument is
1501+
// The first argument is a passthrough vector providing values for disabled lanes
15021502
let (_, element_ty0) = require_simd!(in_ty, SimdFirst);
15031503

15041504
// The element type of the second argument must be a signed integer type of any width:
@@ -1586,6 +1586,107 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15861586
return Ok(v);
15871587
}
15881588

1589+
if name == sym::simd_masked_store {
1590+
// simd_masked_store(values: <N x T>, pointer: *mut T,
1591+
// mask: <N x i{M}>) -> ()
1592+
// * N: number of elements in the input vectors
1593+
// * T: type of the element to load
1594+
// * M: any integer width is supported, will be truncated to i1
1595+
1596+
// The first argument is a passthrough vector providing values for disabled lanes
1597+
let (element_len1, element_ty0) = require_simd!(in_ty, SimdFirst);
1598+
1599+
// The second argument must be a simd vector with an element type that's a pointer
1600+
// to the element type of the first argument
1601+
let (mask_len, element_ty2) = require_simd!(arg_tys[2], SimdSecond);
1602+
1603+
// Of the same length:
1604+
require!(
1605+
in_len == element_len1,
1606+
InvalidMonomorphization::SecondArgumentLength {
1607+
span,
1608+
name,
1609+
in_len,
1610+
in_ty,
1611+
arg_ty: arg_tys[1],
1612+
out_len: element_len1
1613+
}
1614+
);
1615+
require!(
1616+
in_len == mask_len,
1617+
InvalidMonomorphization::ThirdArgumentLength {
1618+
span,
1619+
name,
1620+
in_len,
1621+
in_ty,
1622+
arg_ty: arg_tys[2],
1623+
out_len: mask_len
1624+
}
1625+
);
1626+
1627+
// Pointer type must match the element type
1628+
require!(
1629+
matches!(
1630+
arg_tys[1].kind(),
1631+
ty::RawPtr(p) if p.ty == in_elem && p.ty.kind() == element_ty0.kind() && p.mutbl.is_mut()
1632+
),
1633+
InvalidMonomorphization::ExpectedElementType {
1634+
span,
1635+
name,
1636+
expected_element: in_elem,
1637+
second_arg: arg_tys[1],
1638+
in_elem,
1639+
in_ty,
1640+
mutability: ExpectedPointerMutability::Mut,
1641+
}
1642+
);
1643+
1644+
// The element type of the third argument must be a signed integer type of any width:
1645+
match element_ty2.kind() {
1646+
ty::Int(_) => (),
1647+
_ => {
1648+
return_error!(InvalidMonomorphization::ThirdArgElementType {
1649+
span,
1650+
name,
1651+
expected_element: element_ty2,
1652+
third_arg: arg_tys[2]
1653+
});
1654+
}
1655+
}
1656+
1657+
// Alignment of T, must be a constant integer value:
1658+
let alignment_ty = bx.type_i32();
1659+
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
1660+
1661+
// Truncate the mask vector to a vector of i1s:
1662+
let (mask, mask_ty) = {
1663+
let i1 = bx.type_i1();
1664+
let i1xn = bx.type_vector(i1, in_len);
1665+
(bx.trunc(args[2].immediate(), i1xn), i1xn)
1666+
};
1667+
1668+
let ret_t = bx.type_void();
1669+
1670+
let llvm_pointer = bx.type_ptr();
1671+
1672+
// Type of the vector of elements:
1673+
let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);
1674+
let llvm_elem_vec_str = llvm_vector_str(bx, in_elem, in_len);
1675+
1676+
let llvm_intrinsic = format!("llvm.masked.store.{llvm_elem_vec_str}.p0");
1677+
let fn_ty = bx.type_func(&[llvm_elem_vec_ty, llvm_pointer, alignment_ty, mask_ty], ret_t);
1678+
let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
1679+
let v = bx.call(
1680+
fn_ty,
1681+
None,
1682+
None,
1683+
f,
1684+
&[args[0].immediate(), args[1].immediate(), alignment, mask],
1685+
None,
1686+
);
1687+
return Ok(v);
1688+
}
1689+
15891690
if name == sym::simd_scatter {
15901691
// simd_scatter(values: <N x T>, pointers: <N x *mut T>,
15911692
// mask: <N x i{M}>) -> ()

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
541541
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
542542
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
543543
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(0)),
544+
sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
544545
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
545546
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
546547
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,7 @@ symbols! {
15331533
simd_le,
15341534
simd_lt,
15351535
simd_masked_load,
1536+
simd_masked_store,
15361537
simd_mul,
15371538
simd_ne,
15381539
simd_neg,

0 commit comments

Comments
 (0)