Skip to content

Commit 07bd4e1

Browse files
committed
Add bypass for x86amx
1 parent 965fc68 commit 07bd4e1

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,25 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
361361
}
362362

363363
match self.type_kind(llvm_ty) {
364+
TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
365+
let element_count = self.vector_length(rust_ty);
366+
let element_ty = self.element_type(rust_ty);
367+
368+
let element_size_bits = match self.type_kind(element_ty) {
369+
TypeKind::Half => 16,
370+
TypeKind::Float => 32,
371+
TypeKind::Double => 64,
372+
TypeKind::FP128 => 128,
373+
TypeKind::Integer => self.int_width(element_ty),
374+
TypeKind::Pointer => self.int_width(self.isize_ty()),
375+
_ => bug!(
376+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
377+
),
378+
};
379+
let vector_size_bits = element_size_bits * element_count as u64;
380+
381+
vector_size_bits == 8192
382+
}
364383
TypeKind::BFloat => rust_ty == self.type_i16(),
365384
TypeKind::Vector => {
366385
let llvm_element_count = self.vector_length(llvm_ty) as u64;

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17311731
}
17321732

17331733
match self.type_kind(llvm_ty) {
1734+
TypeKind::X86_AMX => {
1735+
let vector_length = self.vector_length(rust_ty);
1736+
let element_ty = self.element_type(rust_ty);
1737+
let element_ty_str = match self.type_kind(element_ty) {
1738+
TypeKind::Half => "f16",
1739+
TypeKind::Float => "f32",
1740+
TypeKind::Double => "f64",
1741+
TypeKind::FP128 => "f128",
1742+
TypeKind::Integer => &format!("i{}", self.int_width(element_ty)),
1743+
TypeKind::Pointer => "p0",
1744+
_ => bug!(
1745+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
1746+
),
1747+
};
1748+
1749+
let base_name = if is_argument {
1750+
"llvm.x86.cast.vector.to.tile"
1751+
} else {
1752+
"llvm.x86.cast.tile.to.vector"
1753+
};
1754+
let llvm_intrinsic = format!("{base_name}.v{vector_length}{element_ty_str}");
1755+
let fn_ty = self.type_func(&[src_ty], dest_ty);
1756+
let f = self.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
1757+
self.call(fn_ty, None, None, f, &[val], None, None)
1758+
}
17341759
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
17351760
if is_argument {
17361761
self.trunc_int_to_i1_vector(val, dest_ty)

0 commit comments

Comments
 (0)