Skip to content

Add autocast support for x86amx #142251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
fn checked_binop(
&mut self,
oop: OverflowOp,
typ: Ty<'_>,
typ: Ty<'tcx>,
lhs: Self::Value,
rhs: Self::Value,
) -> (Self::Value, Self::Value) {
Expand Down
8 changes: 6 additions & 2 deletions compiler/rustc_codegen_gcc/src/type_of.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Write;

use gccjit::{Struct, Type};
use gccjit::{RValue, Struct, Type};
use rustc_abi as abi;
use rustc_abi::Primitive::*;
use rustc_abi::{
Expand Down Expand Up @@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
unimplemented!();
}

fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
fn fn_decl_backend_type(
&self,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_fn_ptr: RValue<'gcc>,
) -> Type<'gcc> {
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)
Expand Down
223 changes: 201 additions & 22 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::borrow::Borrow;
use std::cmp;
use std::{cmp, iter};

use libc::c_uint;
use rustc_abi::{
ArmCall, BackendRepr, CanonAbi, HasDataLayout, InterruptKind, Primitive, Reg, RegKind, Size,
X86Call,
};
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
Expand All @@ -22,7 +23,7 @@ use smallvec::SmallVec;

use crate::attributes::{self, llfn_attrs_from_instance};
use crate::builder::Builder;
use crate::context::CodegenCx;
use crate::context::{CodegenCx, GenericCx, SCx};
use crate::llvm::{self, Attribute, AttributePlace};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
Expand Down Expand Up @@ -300,8 +301,39 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}
}

pub(crate) enum FunctionSignature<'ll> {
/// The signature is obtained directly from LLVM, and **may not match the Rust signature**
Intrinsic(&'ll Type),
/// The name starts with `llvm.`, but can't obtain the intrinsic ID. May be invalid or upgradable
MaybeInvalidIntrinsic(&'ll Type),
/// Just the Rust signature
Rust(&'ll Type),
}

impl<'ll> FunctionSignature<'ll> {
pub(crate) fn fn_ty(&self) -> &'ll Type {
match self {
FunctionSignature::Intrinsic(fn_ty)
| FunctionSignature::MaybeInvalidIntrinsic(fn_ty)
| FunctionSignature::Rust(fn_ty) => fn_ty,
}
}
}

pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_argument_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type>;
/// When `do_verify` is set, this function performs checks for the signature of LLVM intrinsics
/// and emits a fatal error if it doesn't match. These checks are important,but somewhat expensive
/// So they are only used at function definitions, not at callsites
fn llvm_type(
&self,
cx: &CodegenCx<'ll, 'tcx>,
name: &[u8],
do_verify: bool,
) -> FunctionSignature<'ll>;
/// **If this function is an LLVM intrinsic** checks if the LLVM signature provided matches with this
fn verify_intrinsic_signature(&self, cx: &CodegenCx<'ll, 'tcx>, llvm_ty: &'ll Type) -> bool;
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;

Expand All @@ -314,30 +346,97 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
);

/// Apply attributes to a function call.
fn apply_attrs_callsite(&self, bx: &mut Builder<'_, 'll, 'tcx>, callsite: &'ll Value);
fn apply_attrs_callsite(
&self,
bx: &mut Builder<'_, 'll, 'tcx>,
callsite: &'ll Value,
llfn: &'ll Value,
);
}

impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
pub(crate) fn equate_ty(&self, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
if rust_ty == llvm_ty {
return true;
}

match self.type_kind(llvm_ty) {
TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
let element_count = self.vector_length(rust_ty);
let element_ty = self.element_type(rust_ty);

let element_size_bits = match self.type_kind(element_ty) {
TypeKind::Half => 16,
TypeKind::Float => 32,
TypeKind::Double => 64,
TypeKind::FP128 => 128,
TypeKind::Integer => self.int_width(element_ty),
TypeKind::Pointer => self.int_width(self.isize_ty()),
_ => bug!(
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
),
};
let vector_size_bits = element_size_bits * element_count as u64;

vector_size_bits == 8192
}
TypeKind::BFloat => rust_ty == self.type_i16(),
TypeKind::Vector => {
let llvm_element_count = self.vector_length(llvm_ty) as u64;
let llvm_element_ty = self.element_type(llvm_ty);

if llvm_element_ty == self.type_bf16() {
rust_ty == self.type_vector(self.type_i16(), llvm_element_count)
} else if llvm_element_ty == self.type_i1() {
let int_width = cmp::max(llvm_element_count.next_power_of_two(), 8);
rust_ty == self.type_ix(int_width)
} else {
false
}
}
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
let rust_element_tys = self.struct_element_types(rust_ty);
let llvm_element_tys = self.struct_element_types(llvm_ty);

if rust_element_tys.len() != llvm_element_tys.len() {
return false;
}

iter::zip(rust_element_tys, llvm_element_tys).all(
|(rust_element_ty, llvm_element_ty)| {
self.equate_ty(rust_element_ty, llvm_element_ty)
},
)
}
_ => false,
}
}
}

impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
match &self.ret.mode {
PassMode::Ignore => cx.type_void(),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
PassMode::Indirect { .. } => cx.type_void(),
}
}

fn llvm_argument_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type> {
let indirect_return = matches!(self.ret.mode, PassMode::Indirect { .. });

// Ignore "extra" args from the call site for C variadic functions.
// Only the "fixed" args are part of the LLVM function signature.
let args =
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };

// This capacity calculation is approximate.
let mut llargument_tys = Vec::with_capacity(
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
);
let mut llargument_tys =
Vec::with_capacity(args.len() + if indirect_return { 1 } else { 0 });

let llreturn_ty = match &self.ret.mode {
PassMode::Ignore => cx.type_void(),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
PassMode::Indirect { .. } => {
llargument_tys.push(cx.type_ptr());
cx.type_void()
}
};
if indirect_return {
llargument_tys.push(cx.type_ptr());
}

for arg in args {
// Note that the exact number of arguments pushed here is carefully synchronized with
Expand Down Expand Up @@ -384,10 +483,74 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
llargument_tys.push(llarg_ty);
}

if self.c_variadic {
cx.type_variadic_func(&llargument_tys, llreturn_ty)
llargument_tys
}

fn verify_intrinsic_signature(&self, cx: &CodegenCx<'ll, 'tcx>, llvm_fn_ty: &'ll Type) -> bool {
let rust_return_ty = self.llvm_return_type(cx);
let rust_argument_tys = self.llvm_argument_types(cx);

let llvm_return_ty = cx.get_return_type(llvm_fn_ty);
let llvm_argument_tys = cx.func_params_types(llvm_fn_ty);
let llvm_is_variadic = cx.func_is_variadic(llvm_fn_ty);

if self.c_variadic != llvm_is_variadic || rust_argument_tys.len() != llvm_argument_tys.len()
{
return false;
}

// todo: add bypasses for types not accessible from Rust here
iter::once((rust_return_ty, llvm_return_ty))
.chain(iter::zip(rust_argument_tys, llvm_argument_tys))
.all(|(rust_ty, llvm_ty)| cx.equate_ty(rust_ty, llvm_ty))
}

fn llvm_type(
&self,
cx: &CodegenCx<'ll, 'tcx>,
name: &[u8],
do_verify: bool,
) -> FunctionSignature<'ll> {
let mut maybe_invalid = false;

if name.starts_with(b"llvm.") {
if let Some(intrinsic) = llvm::Intrinsic::lookup(name) {
if !intrinsic.is_overloaded() {
// FIXME: also do this for overloaded intrinsics
let llvm_fn_ty = cx.intrinsic_type(intrinsic, &[]);
if do_verify {
if !self.verify_intrinsic_signature(cx, llvm_fn_ty) {
cx.tcx.dcx().fatal(format!(
"Intrinsic signature mismatch for `{}`: expected signature `{llvm_fn_ty:?}`",
str::from_utf8(name).unwrap()
));
}
}
return FunctionSignature::Intrinsic(llvm_fn_ty);
}
} else {
// it's one of 2 cases,
// - either the base name is invalid
// - it has been superceded by something else, so the intrinsic was removed entirely
// to check for upgrades, we need the `llfn`, so we defer it for now

maybe_invalid = true;
}
}

let return_ty = self.llvm_return_type(cx);
let argument_tys = self.llvm_argument_types(cx);

let fn_ty = if self.c_variadic {
cx.type_variadic_func(&argument_tys, return_ty)
} else {
cx.type_func(&llargument_tys, llreturn_ty)
cx.type_func(&argument_tys, return_ty)
};

if maybe_invalid {
FunctionSignature::MaybeInvalidIntrinsic(fn_ty)
} else {
FunctionSignature::Rust(fn_ty)
}
}

Expand Down Expand Up @@ -530,7 +693,23 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
}
}

fn apply_attrs_callsite(&self, bx: &mut Builder<'_, 'll, 'tcx>, callsite: &'ll Value) {
fn apply_attrs_callsite(
&self,
bx: &mut Builder<'_, 'll, 'tcx>,
callsite: &'ll Value,
llfn: &'ll Value,
) {
// if we are using the LLVM signature, use the LLVM attributes otherwise it might be problematic
let name = llvm::get_value_name(llfn);
if name.starts_with(b"llvm.")
&& let Some(intrinsic) = llvm::Intrinsic::lookup(name)
{
// FIXME: also do this for overloaded intrinsics
if !intrinsic.is_overloaded() {
return;
}
}

let mut func_attrs = SmallVec::<[_; 2]>::new();
if self.ret.layout.is_uninhabited() {
func_attrs.push(llvm::AttributeKind::NoReturn.create_attr(bx.cx.llcx));
Expand Down
Loading
Loading