Skip to content

[CIR] Function type return type improvements #128787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,6 @@ def FuncOp : CIR_Op<"func", [
return getFunctionType().getReturnTypes();
}

/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
/// the 'type' attribute is present and checks if it holds a function type.
/// Ensures getType, getNumFuncArguments, and getNumFuncResults can be
/// called safely.
llvm::LogicalResult verifyType();

//===------------------------------------------------------------------===//
// SymbolOpInterface Methods
//===------------------------------------------------------------------===//
Expand Down
41 changes: 29 additions & 12 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -287,32 +287,44 @@ def CIR_BoolType :
def CIR_FuncType : CIR_Type<"Func", "func"> {
let summary = "CIR function type";
let description = [{
The `!cir.func` is a function type. It consists of a single return type, a
list of parameter types and can optionally be variadic.
The `!cir.func` is a function type. It consists of an optional return type,
a list of parameter types and can optionally be variadic.

Example:

```mlir
!cir.func<!bool ()>
!cir.func<!s32i (!s8i, !s8i)>
!cir.func<!s32i (!s32i, ...)>
!cir.func<()>
!cir.func<() -> bool>
!cir.func<(!s8i, !s8i)>
!cir.func<(!s8i, !s8i) -> !s32i>
!cir.func<(!s32i, ...) -> !s32i>
```
}];

let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
"mlir::Type":$returnType, "bool":$varArg);
OptionalParameter<"mlir::Type">:$optionalReturnType,
"bool":$varArg);
// Use a custom parser to handle argument types with variadic elipsis.
let assemblyFormat = [{
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
`<` custom<FuncTypeParams>($inputs, $varArg) (`->` $optionalReturnType^)? `>`
}];

let builders = [
// Create a FuncType, converting the return type from C-style to
// MLIR-style. If the given return type is `cir::VoidType`, ignore it
// and create the FuncType with no return type, which is how MLIR
// represents function types.
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
CArg<"bool", "false">:$isVarArg), [{
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
return $_get(returnType.getContext(), inputs,
mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType,
isVarArg);
}]>
];

let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Returns whether the function is variadic.
bool isVarArg() const { return getVarArg(); }
Expand All @@ -323,12 +335,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
/// Returns the number of arguments to the function.
unsigned getNumInputs() const { return getInputs().size(); }

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
/// Get the C-style return type of the function, which is !cir.void if the
/// function returns nothing and the actual return type otherwise.
mlir::Type getReturnType() const;

/// Get the MLIR-style return type of the function, which is an empty
/// ArrayRef if the function returns nothing and a single-element ArrayRef
/// with the actual return type otherwise.
llvm::ArrayRef<mlir::Type> getReturnTypes() const;

/// Returns whether the function is returns void.
bool isVoid() const;
/// Does the function type return nothing?
bool hasVoidReturn() const;

/// Returns a clone of this function type with the given argument
/// and result types.
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) {
mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) {
assert(qft.isCanonical());
const FunctionType *ft = cast<FunctionType>(qft.getTypePtr());
// First, check whether we can build the full fucntion type. If the function
// First, check whether we can build the full function type. If the function
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
// the function type.
if (!isFuncTypeConvertible(ft)) {
Expand Down
11 changes: 0 additions & 11 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,17 +416,6 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
}
}

// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
// attribute is present. This can check for preconditions of the
// getNumArguments hook not failing.
LogicalResult cir::FuncOp::verifyType() {
auto type = getFunctionType();
if (!isa<cir::FuncType>(type))
return emitOpError("requires '" + getFunctionTypeAttrName().str() +
"' attribute of function type");
return success();
}

// TODO(CIR): The properties of functions that require verification haven't
// been implemented yet.
mlir::LogicalResult cir::FuncOp::verify() { return success(); }
Expand Down
99 changes: 60 additions & 39 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
//===----------------------------------------------------------------------===//

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncTypeParams(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg);

//===----------------------------------------------------------------------===//
// Get autogenerated stuff
Expand Down Expand Up @@ -282,40 +283,32 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}

mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
// Custom parser that parses function parameters of form `(<type>*, ...)`.
static mlir::ParseResult
parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
return mlir::success();

// `(` `...` `)`
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
return p.parseRParen();
}

// type (`,` type)* (`,` `...`)?
mlir::Type type;
if (p.parseType(type))
return mlir::failure();
params.push_back(type);
while (succeeded(p.parseOptionalComma())) {
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
return p.parseRParen();
}
if (p.parseType(type))
return mlir::failure();
params.push_back(type);
}

return p.parseRParen();
}

void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
return p.parseCommaSeparatedList(
AsmParser::Delimiter::Paren, [&]() -> mlir::ParseResult {
if (isVarArg)
return p.emitError(p.getCurrentLocation(),
"variadic `...` must be the last parameter");
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
return success();
}
mlir::Type type;
if (failed(p.parseType(type)))
return failure();
params.push_back(type);
return success();
});
}

static void printFuncTypeParams(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
p << '(';
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
Expand All @@ -326,11 +319,39 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
p << ')';
}

/// Get the C-style return type of the function, which is !cir.void if the
/// function returns nothing and the actual return type otherwise.
mlir::Type FuncType::getReturnType() const {
if (hasVoidReturn())
return cir::VoidType::get(getContext());
return getOptionalReturnType();
}

/// Get the MLIR-style return type of the function, which is an empty
/// ArrayRef if the function returns nothing and a single-element ArrayRef
/// with the actual return type otherwise.
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
if (hasVoidReturn())
return {};
// Can't use getOptionalReturnType() here because llvm::ArrayRef hold a
// pointer to its elements and doesn't do lifetime extension. That would
// result in returning a pointer to a temporary that has gone out of scope.
return getImpl()->optionalReturnType;
}

bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
// Does the fuction type return nothing?
bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }

mlir::LogicalResult
FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
bool isVarArg) {
if (returnType && mlir::isa<cir::VoidType>(returnType)) {
emitError() << "!cir.func cannot have an explicit 'void' return type";
return mlir::failure();
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// BoolType
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CIR/IR/func.cir
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

module {
// void empty() { }
cir.func @empty() -> !cir.void {
cir.func @empty() {
cir.return
}
// CHECK: cir.func @empty() -> !cir.void {
// CHECK: cir.func @empty() {
// CHECK: cir.return
// CHECK: }

// void voidret() { return; }
cir.func @voidret() -> !cir.void {
cir.func @voidret() {
cir.return
}
// CHECK: cir.func @voidret() -> !cir.void {
// CHECK: cir.func @voidret() {
// CHECK: cir.return
// CHECK: }

Expand Down
12 changes: 6 additions & 6 deletions clang/test/CIR/IR/global.cir
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
cir.global @dp : !cir.ptr<!cir.double>
cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
cir.global @fp : !cir.ptr<!cir.func<()>>
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
}

// CHECK: cir.global @c : !cir.int<s, 8>
Expand Down Expand Up @@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
// CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
// CHECK: cir.global @dp : !cir.ptr<!cir.double>
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
4 changes: 2 additions & 2 deletions clang/test/CIR/func-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s

void empty() { }
// CHECK: cir.func @empty() -> !cir.void {
// CHECK: cir.func @empty() {
// CHECK: cir.return
// CHECK: }

void voidret() { return; }
// CHECK: cir.func @voidret() -> !cir.void {
// CHECK: cir.func @voidret() {
// CHECK: cir.return
// CHECK: }

Expand Down
6 changes: 3 additions & 3 deletions clang/test/CIR/global-var-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ char **cpp;
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>

void (*fp)();
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>

int (*fpii)(int) = 0;
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>

void (*fpvar)(int, ...);
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>