Skip to content

Commit 53998d0

Browse files
committed
Add bypass for x86amx
1 parent ade6ec1 commit 53998d0

File tree

5 files changed

+39
-0
lines changed

5 files changed

+39
-0
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,25 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
361361
}
362362

363363
match self.type_kind(llvm_ty) {
364+
TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
365+
let element_count = self.vector_length(rust_ty);
366+
let element_ty = self.element_type(rust_ty);
367+
368+
let element_size_bits = match self.type_kind(element_ty) {
369+
TypeKind::Half => 16,
370+
TypeKind::Float => 32,
371+
TypeKind::Double => 64,
372+
TypeKind::FP128 => 128,
373+
TypeKind::Integer => self.int_width(element_ty),
374+
TypeKind::Pointer => self.int_width(self.isize_ty()),
375+
_ => bug!(
376+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
377+
),
378+
};
379+
let vector_size_bits = element_size_bits * element_count as u64;
380+
381+
vector_size_bits == 8192
382+
}
364383
TypeKind::BFloat => rust_ty == self.type_i16(),
365384
TypeKind::Vector => {
366385
let llvm_element_count = self.vector_length(llvm_ty) as u64;

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,15 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16501650
}
16511651

16521652
match self.type_kind(llvm_ty) {
1653+
TypeKind::X86_AMX => {
1654+
let base_name = if is_argument {
1655+
"llvm.x86.cast.vector.to.tile"
1656+
} else {
1657+
"llvm.x86.cast.tile.to.vector"
1658+
};
1659+
1660+
self.call_intrinsic(base_name, &[rust_ty], &[val])
1661+
}
16531662
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
16541663
if is_argument {
16551664
self.trunc_int_to_i1_vector(val, dest_ty)

compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,7 @@ impl<'ll> CodegenCx<'ll, '_> {
927927
let t_isize = self.type_isize();
928928
let t_metadata = self.type_metadata();
929929
let t_token = self.type_token();
930+
let x86amx = self.type_x86amx();
930931

931932
ifn!("llvm.wasm.get.exception", fn(t_token) -> ptr);
932933
ifn!("llvm.wasm.get.ehselector", fn(t_token) -> t_i32);
@@ -1039,6 +1040,9 @@ impl<'ll> CodegenCx<'ll, '_> {
10391040
ifn!("llvm.masked.gather", fn(1, t_i32, same_width_vector(0, i1), 0) -> 0);
10401041
ifn!("llvm.masked.scatter", fn(0, 1, t_i32, same_width_vector(0, i1)) -> void);
10411042

1043+
ifn!("llvm.x86.cast.vector.to.tile", fn(0) -> x86amx);
1044+
ifn!("llvm.x86.cast.tile.to.vector", fn(x86amx) -> 0);
1045+
10421046
bug!("Unknown intrinsic: `{base_name}`")
10431047
}
10441048

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,9 @@ unsafe extern "C" {
10851085
pub(crate) fn LLVMTokenTypeInContext(C: &Context) -> &Type;
10861086
pub(crate) fn LLVMMetadataTypeInContext(C: &Context) -> &Type;
10871087

1088+
// X86-specific type for AMX
1089+
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;
1090+
10881091
// Operations on all values
10891092
pub(crate) fn LLVMTypeOf(Val: &Value) -> &Type;
10901093
pub(crate) fn LLVMGetValueName2(Val: &Value, Length: *mut size_t) -> *const c_char;

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
181181
pub(crate) fn type_bf16(&self) -> &'ll Type {
182182
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
183183
}
184+
185+
pub(crate) fn type_x86amx(&self) -> &'ll Type {
186+
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
187+
}
184188
}
185189

186190
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {

0 commit comments

Comments
 (0)