Skip to content

Commit 556eb82

Browse files
authored
[CIR] Function type return type improvements (#128787)
When a C or C++ function has a return type of `void`, the function type is now represented in MLIR as having no return type rather than having a return type of `!cir.void`. This avoids breaking MLIR invariants that require the number of return types and the number of return values to match. Change the assembly format for `cir::FuncType` from having a leading return type to having a trailing return type. In other words, change ``` !cir.func<!returnType (!argTypes)> ``` to ``` !cir.func<(!argTypes) -> !returnType)> ``` Unless the function returns `void`, in which case change ``` !cir.func<!cir.void (!argTypes)> ``` to ``` !cir.func<(!argTypes)> ```
1 parent 110b77f commit 556eb82

File tree

9 files changed

+105
-84
lines changed

9 files changed

+105
-84
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,6 @@ def FuncOp : CIR_Op<"func", [
365365
return getFunctionType().getReturnTypes();
366366
}
367367

368-
/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
369-
/// the 'type' attribute is present and checks if it holds a function type.
370-
/// Ensures getType, getNumFuncArguments, and getNumFuncResults can be
371-
/// called safely.
372-
llvm::LogicalResult verifyType();
373-
374368
//===------------------------------------------------------------------===//
375369
// SymbolOpInterface Methods
376370
//===------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -287,32 +287,44 @@ def CIR_BoolType :
287287
def CIR_FuncType : CIR_Type<"Func", "func"> {
288288
let summary = "CIR function type";
289289
let description = [{
290-
The `!cir.func` is a function type. It consists of a single return type, a
291-
list of parameter types and can optionally be variadic.
290+
The `!cir.func` is a function type. It consists of an optional return type,
291+
a list of parameter types and can optionally be variadic.
292292

293293
Example:
294294

295295
```mlir
296-
!cir.func<!bool ()>
297-
!cir.func<!s32i (!s8i, !s8i)>
298-
!cir.func<!s32i (!s32i, ...)>
296+
!cir.func<()>
297+
!cir.func<() -> bool>
298+
!cir.func<(!s8i, !s8i)>
299+
!cir.func<(!s8i, !s8i) -> !s32i>
300+
!cir.func<(!s32i, ...) -> !s32i>
299301
```
300302
}];
301303

302304
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
303-
"mlir::Type":$returnType, "bool":$varArg);
305+
OptionalParameter<"mlir::Type">:$optionalReturnType,
306+
"bool":$varArg);
307+
// Use a custom parser to handle argument types with variadic elipsis.
304308
let assemblyFormat = [{
305-
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
309+
`<` custom<FuncTypeParams>($inputs, $varArg) (`->` $optionalReturnType^)? `>`
306310
}];
307311

308312
let builders = [
313+
// Create a FuncType, converting the return type from C-style to
314+
// MLIR-style. If the given return type is `cir::VoidType`, ignore it
315+
// and create the FuncType with no return type, which is how MLIR
316+
// represents function types.
309317
TypeBuilderWithInferredContext<(ins
310318
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
311319
CArg<"bool", "false">:$isVarArg), [{
312-
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
320+
return $_get(returnType.getContext(), inputs,
321+
mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType,
322+
isVarArg);
313323
}]>
314324
];
315325

326+
let genVerifyDecl = 1;
327+
316328
let extraClassDeclaration = [{
317329
/// Returns whether the function is variadic.
318330
bool isVarArg() const { return getVarArg(); }
@@ -323,12 +335,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
323335
/// Returns the number of arguments to the function.
324336
unsigned getNumInputs() const { return getInputs().size(); }
325337

326-
/// Returns the result type of the function as an ArrayRef, enabling better
327-
/// integration with generic MLIR utilities.
338+
/// Get the C-style return type of the function, which is !cir.void if the
339+
/// function returns nothing and the actual return type otherwise.
340+
mlir::Type getReturnType() const;
341+
342+
/// Get the MLIR-style return type of the function, which is an empty
343+
/// ArrayRef if the function returns nothing and a single-element ArrayRef
344+
/// with the actual return type otherwise.
328345
llvm::ArrayRef<mlir::Type> getReturnTypes() const;
329346

330-
/// Returns whether the function is returns void.
331-
bool isVoid() const;
347+
/// Does the function type return nothing?
348+
bool hasVoidReturn() const;
332349

333350
/// Returns a clone of this function type with the given argument
334351
/// and result types.

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) {
6060
mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) {
6161
assert(qft.isCanonical());
6262
const FunctionType *ft = cast<FunctionType>(qft.getTypePtr());
63-
// First, check whether we can build the full fucntion type. If the function
63+
// First, check whether we can build the full function type. If the function
6464
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
6565
// the function type.
6666
if (!isFuncTypeConvertible(ft)) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -416,17 +416,6 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
416416
}
417417
}
418418

419-
// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
420-
// attribute is present. This can check for preconditions of the
421-
// getNumArguments hook not failing.
422-
LogicalResult cir::FuncOp::verifyType() {
423-
auto type = getFunctionType();
424-
if (!isa<cir::FuncType>(type))
425-
return emitOpError("requires '" + getFunctionTypeAttrName().str() +
426-
"' attribute of function type");
427-
return success();
428-
}
429-
430419
// TODO(CIR): The properties of functions that require verification haven't
431420
// been implemented yet.
432421
mlir::LogicalResult cir::FuncOp::verify() { return success(); }

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
//===----------------------------------------------------------------------===//
2222

2323
static mlir::ParseResult
24-
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
25-
bool &isVarArg);
26-
static void printFuncTypeArgs(mlir::AsmPrinter &p,
27-
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
24+
parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
25+
bool &isVarArg);
26+
static void printFuncTypeParams(mlir::AsmPrinter &p,
27+
mlir::ArrayRef<mlir::Type> params,
28+
bool isVarArg);
2829

2930
//===----------------------------------------------------------------------===//
3031
// Get autogenerated stuff
@@ -282,40 +283,32 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
282283
return get(llvm::to_vector(inputs), results[0], isVarArg());
283284
}
284285

285-
mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
286-
llvm::SmallVector<mlir::Type> &params,
287-
bool &isVarArg) {
286+
// Custom parser that parses function parameters of form `(<type>*, ...)`.
287+
static mlir::ParseResult
288+
parseFuncTypeParams(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
289+
bool &isVarArg) {
288290
isVarArg = false;
289-
// `(` `)`
290-
if (succeeded(p.parseOptionalRParen()))
291-
return mlir::success();
292-
293-
// `(` `...` `)`
294-
if (succeeded(p.parseOptionalEllipsis())) {
295-
isVarArg = true;
296-
return p.parseRParen();
297-
}
298-
299-
// type (`,` type)* (`,` `...`)?
300-
mlir::Type type;
301-
if (p.parseType(type))
302-
return mlir::failure();
303-
params.push_back(type);
304-
while (succeeded(p.parseOptionalComma())) {
305-
if (succeeded(p.parseOptionalEllipsis())) {
306-
isVarArg = true;
307-
return p.parseRParen();
308-
}
309-
if (p.parseType(type))
310-
return mlir::failure();
311-
params.push_back(type);
312-
}
313-
314-
return p.parseRParen();
315-
}
316-
317-
void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
318-
bool isVarArg) {
291+
return p.parseCommaSeparatedList(
292+
AsmParser::Delimiter::Paren, [&]() -> mlir::ParseResult {
293+
if (isVarArg)
294+
return p.emitError(p.getCurrentLocation(),
295+
"variadic `...` must be the last parameter");
296+
if (succeeded(p.parseOptionalEllipsis())) {
297+
isVarArg = true;
298+
return success();
299+
}
300+
mlir::Type type;
301+
if (failed(p.parseType(type)))
302+
return failure();
303+
params.push_back(type);
304+
return success();
305+
});
306+
}
307+
308+
static void printFuncTypeParams(mlir::AsmPrinter &p,
309+
mlir::ArrayRef<mlir::Type> params,
310+
bool isVarArg) {
311+
p << '(';
319312
llvm::interleaveComma(params, p,
320313
[&p](mlir::Type type) { p.printType(type); });
321314
if (isVarArg) {
@@ -326,11 +319,39 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
326319
p << ')';
327320
}
328321

322+
/// Get the C-style return type of the function, which is !cir.void if the
323+
/// function returns nothing and the actual return type otherwise.
324+
mlir::Type FuncType::getReturnType() const {
325+
if (hasVoidReturn())
326+
return cir::VoidType::get(getContext());
327+
return getOptionalReturnType();
328+
}
329+
330+
/// Get the MLIR-style return type of the function, which is an empty
331+
/// ArrayRef if the function returns nothing and a single-element ArrayRef
332+
/// with the actual return type otherwise.
329333
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
330-
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
334+
if (hasVoidReturn())
335+
return {};
336+
// Can't use getOptionalReturnType() here because llvm::ArrayRef hold a
337+
// pointer to its elements and doesn't do lifetime extension. That would
338+
// result in returning a pointer to a temporary that has gone out of scope.
339+
return getImpl()->optionalReturnType;
331340
}
332341

333-
bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
342+
// Does the fuction type return nothing?
343+
bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }
344+
345+
mlir::LogicalResult
346+
FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
347+
llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
348+
bool isVarArg) {
349+
if (returnType && mlir::isa<cir::VoidType>(returnType)) {
350+
emitError() << "!cir.func cannot have an explicit 'void' return type";
351+
return mlir::failure();
352+
}
353+
return mlir::success();
354+
}
334355

335356
//===----------------------------------------------------------------------===//
336357
// BoolType

clang/test/CIR/IR/func.cir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22

33
module {
44
// void empty() { }
5-
cir.func @empty() -> !cir.void {
5+
cir.func @empty() {
66
cir.return
77
}
8-
// CHECK: cir.func @empty() -> !cir.void {
8+
// CHECK: cir.func @empty() {
99
// CHECK: cir.return
1010
// CHECK: }
1111

1212
// void voidret() { return; }
13-
cir.func @voidret() -> !cir.void {
13+
cir.func @voidret() {
1414
cir.return
1515
}
16-
// CHECK: cir.func @voidret() -> !cir.void {
16+
// CHECK: cir.func @voidret() {
1717
// CHECK: cir.return
1818
// CHECK: }
1919

clang/test/CIR/IR/global.cir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
3030
cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
3131
cir.global @dp : !cir.ptr<!cir.double>
3232
cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
33-
cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
34-
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
35-
cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
33+
cir.global @fp : !cir.ptr<!cir.func<()>>
34+
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
35+
cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
3636
}
3737

3838
// CHECK: cir.global @c : !cir.int<s, 8>
@@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
6464
// CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
6565
// CHECK: cir.global @dp : !cir.ptr<!cir.double>
6666
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
67-
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
68-
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
69-
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
67+
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
68+
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
69+
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>

clang/test/CIR/func-simple.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s
33

44
void empty() { }
5-
// CHECK: cir.func @empty() -> !cir.void {
5+
// CHECK: cir.func @empty() {
66
// CHECK: cir.return
77
// CHECK: }
88

99
void voidret() { return; }
10-
// CHECK: cir.func @voidret() -> !cir.void {
10+
// CHECK: cir.func @voidret() {
1111
// CHECK: cir.return
1212
// CHECK: }
1313

clang/test/CIR/global-var-simple.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ char **cpp;
9292
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
9393

9494
void (*fp)();
95-
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
95+
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
9696

9797
int (*fpii)(int) = 0;
98-
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
98+
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
9999

100100
void (*fpvar)(int, ...);
101-
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
101+
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>

0 commit comments

Comments
 (0)