Skip to content

Commit 8c6a8c3

Browse files
committed
[CIR] Add support for indirect calls
1 parent d08b176 commit 8c6a8c3

File tree

12 files changed

+188
-35
lines changed

12 files changed

+188
-35
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
225225
callee.getFunctionType().getReturnType(), operands);
226226
}
227227

228+
cir::CallOp createIndirectCallOp(mlir::Location loc,
229+
mlir::Value indirectTarget,
230+
cir::FuncType funcType,
231+
mlir::ValueRange operands) {
232+
return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
233+
operands);
234+
}
235+
228236
//===--------------------------------------------------------------------===//
229237
// Cast/Conversion Operators
230238
//===--------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,13 +1798,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
17981798
DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
17991799
let extraClassDeclaration = [{
18001800
/// Get the argument operands to the called function.
1801-
mlir::OperandRange getArgOperands() {
1802-
return getArgs();
1803-
}
1804-
1805-
mlir::MutableOperandRange getArgOperandsMutable() {
1806-
return getArgsMutable();
1807-
}
1801+
mlir::OperandRange getArgOperands();
1802+
mlir::MutableOperandRange getArgOperandsMutable();
18081803

18091804
/// Return the callee of this operation
18101805
mlir::CallInterfaceCallable getCallableForCallee() {
@@ -1826,6 +1821,9 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18261821
::mlir::Attribute removeArgAttrsAttr() { return {}; }
18271822
::mlir::Attribute removeResAttrsAttr() { return {}; }
18281823

1824+
bool isIndirect() { return !getCallee(); }
1825+
mlir::Value getIndirectCall();
1826+
18291827
void setArg(unsigned index, mlir::Value value) {
18301828
setOperand(index, value);
18311829
}
@@ -1839,16 +1837,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18391837
// the upstreaming process moves on. The verifiers is also missing for now,
18401838
// will add in the future.
18411839

1842-
dag commonArgs = (ins FlatSymbolRefAttr:$callee,
1843-
Variadic<CIR_AnyType>:$args);
1840+
dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
1841+
Variadic<CIR_AnyType>:$args);
18441842
}
18451843

18461844
def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18471845
let summary = "call a function";
18481846
let description = [{
1849-
The `cir.call` operation represents a direct call to a function that is
1850-
within the same symbol scope as the call. The callee is encoded as a symbol
1851-
reference attribute named `callee`.
1847+
The `cir.call` operation represents a function call. It could represent
1848+
either a direct call or an indirect call.
1849+
1850+
If the operation represents a direct call, the callee should be defined
1851+
within the same symbol scope as the call. The `callee` attribute contains a
1852+
symbol reference to the callee function. All operands of this operation are
1853+
arguments to the callee function.
1854+
1855+
If the operation represents an indirect call, the `callee` attribute is
1856+
empty. The first operand of this operation must be a pointer to the callee
1857+
function. All the rest operands are arguments to the callee function.
18521858

18531859
Example:
18541860

@@ -1861,13 +1867,23 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18611867
let arguments = commonArgs;
18621868

18631869
let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
1864-
"mlir::Type":$resType,
1865-
"mlir::ValueRange":$operands), [{
1870+
"mlir::Type":$resType,
1871+
"mlir::ValueRange":$operands),
1872+
[{
18661873
$_state.addOperands(operands);
18671874
$_state.addAttribute("callee", callee);
18681875
if (resType && !isa<VoidType>(resType))
18691876
$_state.addTypes(resType);
1870-
}]>];
1877+
}]>,
1878+
OpBuilder<(ins "mlir::Value":$callee, "mlir::Type":$resType,
1879+
"mlir::ValueRange":$operands),
1880+
[{
1881+
$_state.addOperands(callee);
1882+
$_state.addOperands(operands);
1883+
if (resType && !isa<VoidType>(resType))
1884+
$_state.addTypes(resType);
1885+
}]>,
1886+
];
18711887
}
18721888

18731889
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ struct MissingFeatures {
9393
static bool opCallChainCall() { return false; }
9494
static bool opCallNoPrototypeFunc() { return false; }
9595
static bool opCallMustTail() { return false; }
96-
static bool opCallIndirect() { return false; }
9796
static bool opCallVirtual() { return false; }
9897
static bool opCallInAlloca() { return false; }
9998
static bool opCallAttrs() { return false; }

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args,
9797

9898
static cir::CIRCallOpInterface
9999
emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
100+
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
100101
cir::FuncOp directFuncOp,
101102
const SmallVectorImpl<mlir::Value> &cirCallArgs) {
102103
CIRGenBuilderTy &builder = cgf.getBuilder();
@@ -105,7 +106,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
105106
assert(!cir::MissingFeatures::invokeOp());
106107

107108
assert(builder.getInsertionBlock() && "expected valid basic block");
108-
assert(!cir::MissingFeatures::opCallIndirect());
109+
110+
if (indirectFuncTy) {
111+
// TODO(cir): Set calling convention for indirect calls.
112+
assert(!cir::MissingFeatures::opCallCallConv());
113+
return builder.createIndirectCallOp(callLoc, indirectFuncVal,
114+
indirectFuncTy, cirCallArgs);
115+
}
109116

110117
return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
111118
}
@@ -134,6 +141,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
134141
cir::CIRCallOpInterface *callOp,
135142
mlir::Location loc) {
136143
QualType retTy = funcInfo.getReturnType();
144+
cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo);
137145

138146
SmallVector<mlir::Value, 16> cirCallArgs(args.size());
139147

@@ -185,12 +193,26 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
185193

186194
assert(!cir::MissingFeatures::invokeOp());
187195

188-
auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr);
189-
assert(!cir::MissingFeatures::opCallIndirect());
196+
cir::FuncType indirectFuncTy;
197+
mlir::Value indirectFuncVal;
198+
cir::FuncOp directFuncOp;
199+
if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr))
200+
directFuncOp = fnOp;
201+
else {
202+
[[maybe_unused]] auto resultTypes = calleePtr->getResultTypes();
203+
[[maybe_unused]] auto funcPtrTy =
204+
mlir::dyn_cast<cir::PointerType>(resultTypes.front());
205+
assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) &&
206+
"expected pointer to function");
207+
208+
indirectFuncTy = cirFuncTy;
209+
indirectFuncVal = calleePtr->getResult(0);
210+
}
211+
190212
assert(!cir::MissingFeatures::opCallAttrs());
191213

192-
cir::CIRCallOpInterface theCall =
193-
emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs);
214+
cir::CIRCallOpInterface theCall = emitCallLikeOp(
215+
*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);
194216

195217
if (callOp)
196218
*callOp = theCall;
@@ -290,7 +312,7 @@ void CIRGenFunction::emitCallArgs(
290312

291313
auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg,
292314
RValue emittedArg) {
293-
if (callee.hasFunctionDecl() || i >= callee.getNumParams())
315+
if (!callee.hasFunctionDecl() || i >= callee.getNumParams())
294316
return;
295317
auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>();
296318
if (!ps)

clang/lib/CIR/CodeGen/CIRGenCall.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,20 @@ class CIRGenFunction;
2525

2626
/// Abstract information about a function or function prototype.
2727
class CIRGenCalleeInfo {
28+
const clang::FunctionProtoType *calleeProtoTy;
2829
clang::GlobalDecl calleeDecl;
2930

3031
public:
31-
explicit CIRGenCalleeInfo() : calleeDecl() {}
32+
explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {}
33+
CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy,
34+
clang::GlobalDecl calleeDecl)
35+
: calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {}
3236
CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {}
37+
38+
const clang::FunctionProtoType *getCalleeFunctionProtoType() const {
39+
return calleeProtoTy;
40+
}
41+
clang::GlobalDecl getCalleeDecl() const { return calleeDecl; }
3342
};
3443

3544
class CIRGenCallee {

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -937,8 +937,28 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
937937
return emitDirectCallee(cgm, funcDecl);
938938
}
939939

940-
cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind");
941-
return {};
940+
assert(!cir::MissingFeatures::opCallPseudoDtor());
941+
942+
// Otherwise, we have an indirect reference.
943+
mlir::Value calleePtr;
944+
QualType functionType;
945+
if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) {
946+
calleePtr = emitScalarExpr(e);
947+
functionType = ptrType->getPointeeType();
948+
} else {
949+
functionType = e->getType();
950+
calleePtr = emitLValue(e).getPointer();
951+
}
952+
assert(functionType->isFunctionType());
953+
954+
GlobalDecl gd;
955+
if (const auto *vd =
956+
dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee()))
957+
gd = GlobalDecl(vd);
958+
959+
CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd);
960+
CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp());
961+
return callee;
942962
}
943963

944964
RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,

clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
1717

1818
#include "clang/AST/CanonicalType.h"
19+
#include "clang/CIR/MissingFeatures.h"
1920
#include "llvm/ADT/FoldingSet.h"
2021
#include "llvm/Support/TrailingObjects.h"
2122

clang/lib/CIR/CodeGen/CIRGenTypes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ class CIRGenTypes {
6565
/// types will be in this set.
6666
llvm::SmallPtrSet<const clang::Type *, 4> recordsBeingLaidOut;
6767

68-
llvm::SmallPtrSet<const CIRGenFunctionInfo *, 4> functionsBeingProcessed;
6968
/// Heper for convertType.
7069
mlir::Type convertFunctionTypeInternal(clang::QualType ft);
7170

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
464464
// CallOp
465465
//===----------------------------------------------------------------------===//
466466

467+
mlir::OperandRange cir::CallOp::getArgOperands() {
468+
if (isIndirect())
469+
return getArgs().drop_front(1);
470+
return getArgs();
471+
}
472+
473+
mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
474+
mlir::MutableOperandRange args = getArgsMutable();
475+
if (isIndirect())
476+
return args.slice(1, args.size() - 1);
477+
return args;
478+
}
479+
480+
mlir::Value cir::CallOp::getIndirectCall() {
481+
assert(isIndirect());
482+
return getOperand(0);
483+
}
484+
467485
/// Return the operand at index 'i'.
468486
Value cir::CallOp::getArgOperand(unsigned i) {
469-
assert(!cir::MissingFeatures::opCallIndirect());
487+
if (isIndirect())
488+
++i;
470489
return getOperand(i);
471490
}
472491

473492
/// Return the number of operands.
474493
unsigned cir::CallOp::getNumArgOperands() {
475-
assert(!cir::MissingFeatures::opCallIndirect());
494+
if (isIndirect())
495+
return this->getOperation()->getNumOperands() - 1;
476496
return this->getOperation()->getNumOperands();
477497
}
478498

@@ -483,9 +503,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
483503
mlir::FlatSymbolRefAttr calleeAttr;
484504
llvm::ArrayRef<mlir::Type> allResultTypes;
485505

506+
// If we cannot parse a string callee, it means this is an indirect call.
486507
if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
487-
.has_value())
488-
return mlir::failure();
508+
.has_value()) {
509+
OpAsmParser::UnresolvedOperand indirectVal;
510+
// Do not resolve right now, since we need to figure out the type
511+
if (parser.parseOperand(indirectVal).failed())
512+
return failure();
513+
ops.push_back(indirectVal);
514+
}
489515

490516
if (parser.parseLParen())
491517
return mlir::failure();
@@ -517,13 +543,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
517543

518544
static void printCallCommon(mlir::Operation *op,
519545
mlir::FlatSymbolRefAttr calleeSym,
546+
mlir::Value indirectCallee,
520547
mlir::OpAsmPrinter &printer) {
521548
printer << ' ';
522549

523550
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
524551
auto ops = callLikeOp.getArgOperands();
525552

526-
printer.printAttributeWithoutType(calleeSym);
553+
if (calleeSym) {
554+
// Direct calls
555+
printer.printAttributeWithoutType(calleeSym);
556+
} else {
557+
// Indirect calls
558+
assert(indirectCallee);
559+
printer << indirectCallee;
560+
}
527561
printer << "(" << ops << ")";
528562

529563
printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
@@ -539,15 +573,16 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
539573
}
540574

541575
void cir::CallOp::print(mlir::OpAsmPrinter &p) {
542-
printCallCommon(*this, getCalleeAttr(), p);
576+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
577+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
543578
}
544579

545580
static LogicalResult
546581
verifyCallCommInSymbolUses(mlir::Operation *op,
547582
SymbolTableCollection &symbolTable) {
548583
auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
549584
if (!fnAttr)
550-
return mlir::failure();
585+
return mlir::success();
551586

552587
auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
553588
if (!fn)

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,15 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
674674
llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
675675
converter->convertType(fn.getFunctionType()));
676676
} else { // indirect call
677-
assert(!cir::MissingFeatures::opCallIndirect());
678-
return op->emitError("Indirect calls are NYI");
677+
assert(!op->getOperands().empty() &&
678+
"operands list must no be empty for the indirect call");
679+
auto calleeTy = op->getOperands().front().getType();
680+
auto calleePtrTy = cast<cir::PointerType>(calleeTy);
681+
auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
682+
calleeFuncTy.dump();
683+
converter->convertType(calleeFuncTy).dump();
684+
llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
685+
converter->convertType(calleeFuncTy));
679686
}
680687

681688
assert(!cir::MissingFeatures::opCallLandingPad());
@@ -1501,6 +1508,15 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
15011508
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
15021509
return mlir::BFloat16Type::get(type.getContext());
15031510
});
1511+
converter.addConversion([&](cir::FuncType type) -> mlir::Type {
1512+
auto result = converter.convertType(type.getReturnType());
1513+
llvm::SmallVector<mlir::Type> arguments;
1514+
arguments.reserve(type.getNumInputs());
1515+
if (converter.convertTypes(type.getInputs(), arguments).failed())
1516+
llvm_unreachable("Failed to convert function type parameters");
1517+
auto varArg = type.isVarArg();
1518+
return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
1519+
});
15041520
converter.addConversion([&](cir::RecordType type) -> mlir::Type {
15051521
// Convert struct members.
15061522
llvm::SmallVector<mlir::Type> llvmMembers;

clang/test/CIR/CodeGen/call.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,17 @@ int f6() {
4242

4343
// LLVM-LABEL: define i32 @_Z2f6v() {
4444
// LLVM: %{{.+}} = call i32 @_Z2f5iPib(i32 2, ptr %{{.+}}, i1 false)
45+
46+
int f7(int (*ptr)(int, int)) {
47+
return ptr(1, 2);
48+
}
49+
50+
// CIR-LABEL: cir.func @_Z2f7PFiiiE
51+
// CIR: %[[#ptr:]] = cir.load %{{.+}} : !cir.ptr<!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>>, !cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>
52+
// CIR-NEXT: %[[#a:]] = cir.const #cir.int<1> : !s32i
53+
// CIR-NEXT: %[[#b:]] = cir.const #cir.int<2> : !s32i
54+
// CIR-NEXT: %{{.+}} = cir.call %[[#ptr]](%[[#a]], %[[#b]]) : (!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>, !s32i, !s32i) -> !s32i
55+
56+
// LLVM-LABEL: define i32 @_Z2f7PFiiiE
57+
// LLVM: %[[#ptr:]] = load ptr, ptr %{{.+}}
58+
// LLVM-NEXT: %{{.+}} = call i32 %[[#ptr]](i32 1, i32 2)

0 commit comments

Comments
 (0)