Skip to content

Commit 6148497

Browse files
committed
[CIR] Add support for indirect calls
1 parent d08b176 commit 6148497

File tree

12 files changed

+212
-40
lines changed

12 files changed

+212
-40
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
225225
callee.getFunctionType().getReturnType(), operands);
226226
}
227227

228+
cir::CallOp createIndirectCallOp(mlir::Location loc,
229+
mlir::Value indirectTarget,
230+
cir::FuncType funcType,
231+
mlir::ValueRange operands) {
232+
return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
233+
operands);
234+
}
235+
228236
//===--------------------------------------------------------------------===//
229237
// Cast/Conversion Operators
230238
//===--------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,13 +1798,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
17981798
DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
17991799
let extraClassDeclaration = [{
18001800
/// Get the argument operands to the called function.
1801-
mlir::OperandRange getArgOperands() {
1802-
return getArgs();
1803-
}
1804-
1805-
mlir::MutableOperandRange getArgOperandsMutable() {
1806-
return getArgsMutable();
1807-
}
1801+
mlir::OperandRange getArgOperands();
1802+
mlir::MutableOperandRange getArgOperandsMutable();
18081803

18091804
/// Return the callee of this operation
18101805
mlir::CallInterfaceCallable getCallableForCallee() {
@@ -1826,8 +1821,17 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18261821
::mlir::Attribute removeArgAttrsAttr() { return {}; }
18271822
::mlir::Attribute removeResAttrsAttr() { return {}; }
18281823

1824+
bool isIndirect() { return !getCallee(); }
1825+
mlir::Value getIndirectCall();
1826+
18291827
void setArg(unsigned index, mlir::Value value) {
1830-
setOperand(index, value);
1828+
if (!isIndirect()) {
1829+
setOperand(index, value);
1830+
return;
1831+
}
1832+
1833+
// For indirect call, the operand list is shifted by one.
1834+
setOperand(index + 1, value);
18311835
}
18321836
}];
18331837

@@ -1839,16 +1843,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18391843
// the upstreaming process moves on. The verifiers is also missing for now,
18401844
// will add in the future.
18411845

1842-
dag commonArgs = (ins FlatSymbolRefAttr:$callee,
1843-
Variadic<CIR_AnyType>:$args);
1846+
dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
1847+
Variadic<CIR_AnyType>:$args);
18441848
}
18451849

18461850
def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18471851
let summary = "call a function";
18481852
let description = [{
1849-
The `cir.call` operation represents a direct call to a function that is
1850-
within the same symbol scope as the call. The callee is encoded as a symbol
1851-
reference attribute named `callee`.
1853+
The `cir.call` operation represents a function call. It could represent
1854+
either a direct call or an indirect call.
1855+
1856+
If the operation represents a direct call, the callee should be defined
1857+
within the same symbol scope as the call. The `callee` attribute contains a
1858+
symbol reference to the callee function. All operands of this operation are
1859+
arguments to the callee function.
1860+
1861+
If the operation represents an indirect call, the `callee` attribute is
1862+
empty. The first operand of this operation must be a pointer to the callee
1863+
function. The rest operands are arguments to the callee function.
18521864

18531865
Example:
18541866

@@ -1860,14 +1872,25 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18601872
let results = (outs Optional<CIR_AnyType>:$result);
18611873
let arguments = commonArgs;
18621874

1863-
let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
1864-
"mlir::Type":$resType,
1865-
"mlir::ValueRange":$operands), [{
1875+
let builders = [
1876+
// Build a call op for a direct call
1877+
OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
1878+
"mlir::ValueRange":$operands), [{
1879+
assert(callee && "callee attribute is required for direct call");
18661880
$_state.addOperands(operands);
18671881
$_state.addAttribute("callee", callee);
18681882
if (resType && !isa<VoidType>(resType))
18691883
$_state.addTypes(resType);
1870-
}]>];
1884+
}]>,
1885+
// Build a call op for an indirect call
1886+
OpBuilder<(ins "mlir::Value":$calleePtr, "mlir::Type":$resType,
1887+
"mlir::ValueRange":$operands), [{
1888+
$_state.addOperands(calleePtr);
1889+
$_state.addOperands(operands);
1890+
if (resType && !isa<VoidType>(resType))
1891+
$_state.addTypes(resType);
1892+
}]>,
1893+
];
18711894
}
18721895

18731896
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ struct MissingFeatures {
9393
static bool opCallChainCall() { return false; }
9494
static bool opCallNoPrototypeFunc() { return false; }
9595
static bool opCallMustTail() { return false; }
96-
static bool opCallIndirect() { return false; }
9796
static bool opCallVirtual() { return false; }
9897
static bool opCallInAlloca() { return false; }
9998
static bool opCallAttrs() { return false; }

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args,
9797

9898
static cir::CIRCallOpInterface
9999
emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
100+
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
100101
cir::FuncOp directFuncOp,
101102
const SmallVectorImpl<mlir::Value> &cirCallArgs) {
102103
CIRGenBuilderTy &builder = cgf.getBuilder();
@@ -105,7 +106,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
105106
assert(!cir::MissingFeatures::invokeOp());
106107

107108
assert(builder.getInsertionBlock() && "expected valid basic block");
108-
assert(!cir::MissingFeatures::opCallIndirect());
109+
110+
if (indirectFuncTy) {
111+
// TODO(cir): Set calling convention for indirect calls.
112+
assert(!cir::MissingFeatures::opCallCallConv());
113+
return builder.createIndirectCallOp(callLoc, indirectFuncVal,
114+
indirectFuncTy, cirCallArgs);
115+
}
109116

110117
return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
111118
}
@@ -134,6 +141,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
134141
cir::CIRCallOpInterface *callOp,
135142
mlir::Location loc) {
136143
QualType retTy = funcInfo.getReturnType();
144+
cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo);
137145

138146
SmallVector<mlir::Value, 16> cirCallArgs(args.size());
139147

@@ -185,12 +193,27 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
185193

186194
assert(!cir::MissingFeatures::invokeOp());
187195

188-
auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr);
189-
assert(!cir::MissingFeatures::opCallIndirect());
196+
cir::FuncType indirectFuncTy;
197+
mlir::Value indirectFuncVal;
198+
cir::FuncOp directFuncOp;
199+
if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr)) {
200+
directFuncOp = fnOp;
201+
} else {
202+
[[maybe_unused]] mlir::ValueTypeRange<mlir::ResultRange> resultTypes =
203+
calleePtr->getResultTypes();
204+
[[maybe_unused]] auto funcPtrTy =
205+
mlir::dyn_cast<cir::PointerType>(resultTypes.front());
206+
assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) &&
207+
"expected pointer to function");
208+
209+
indirectFuncTy = cirFuncTy;
210+
indirectFuncVal = calleePtr->getResult(0);
211+
}
212+
190213
assert(!cir::MissingFeatures::opCallAttrs());
191214

192-
cir::CIRCallOpInterface theCall =
193-
emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs);
215+
cir::CIRCallOpInterface theCall = emitCallLikeOp(
216+
*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);
194217

195218
if (callOp)
196219
*callOp = theCall;
@@ -290,7 +313,7 @@ void CIRGenFunction::emitCallArgs(
290313

291314
auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg,
292315
RValue emittedArg) {
293-
if (callee.hasFunctionDecl() || i >= callee.getNumParams())
316+
if (!callee.hasFunctionDecl() || i >= callee.getNumParams())
294317
return;
295318
auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>();
296319
if (!ps)

clang/lib/CIR/CodeGen/CIRGenCall.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,20 @@ class CIRGenFunction;
2525

2626
/// Abstract information about a function or function prototype.
2727
class CIRGenCalleeInfo {
28+
const clang::FunctionProtoType *calleeProtoTy;
2829
clang::GlobalDecl calleeDecl;
2930

3031
public:
31-
explicit CIRGenCalleeInfo() : calleeDecl() {}
32+
explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {}
33+
CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy,
34+
clang::GlobalDecl calleeDecl)
35+
: calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {}
3236
CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {}
37+
38+
const clang::FunctionProtoType *getCalleeFunctionProtoType() const {
39+
return calleeProtoTy;
40+
}
41+
clang::GlobalDecl getCalleeDecl() const { return calleeDecl; }
3342
};
3443

3544
class CIRGenCallee {

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -931,14 +931,43 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
931931
implicitCast->getCastKind() == CK_BuiltinFnToFnPtr) {
932932
return emitCallee(implicitCast->getSubExpr());
933933
}
934+
// When performing an indirect call through a function pointer lvalue, the
935+
// function pointer lvalue is implicitly converted to an rvalue through an
936+
// lvalue-to-rvalue conversion.
937+
assert(implicitCast->getCastKind() == CK_LValueToRValue &&
938+
"unexpected implicit cast on function pointers");
934939
} else if (const auto *declRef = dyn_cast<DeclRefExpr>(e)) {
935940
// Resolve direct calls.
936-
if (const auto *funcDecl = dyn_cast<FunctionDecl>(declRef->getDecl()))
937-
return emitDirectCallee(cgm, funcDecl);
941+
const auto *funcDecl = cast<FunctionDecl>(declRef->getDecl());
942+
return emitDirectCallee(cgm, funcDecl);
943+
} else if (isa<MemberExpr>(e)) {
944+
cgm.errorNYI(e->getSourceRange(),
945+
"emitCallee: call to member function is NYI");
946+
return {};
938947
}
939948

940-
cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind");
941-
return {};
949+
assert(!cir::MissingFeatures::opCallPseudoDtor());
950+
951+
// Otherwise, we have an indirect reference.
952+
mlir::Value calleePtr;
953+
QualType functionType;
954+
if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) {
955+
calleePtr = emitScalarExpr(e);
956+
functionType = ptrType->getPointeeType();
957+
} else {
958+
functionType = e->getType();
959+
calleePtr = emitLValue(e).getPointer();
960+
}
961+
assert(functionType->isFunctionType());
962+
963+
GlobalDecl gd;
964+
if (const auto *vd =
965+
dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee()))
966+
gd = GlobalDecl(vd);
967+
968+
CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd);
969+
CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp());
970+
return callee;
942971
}
943972

944973
RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,

clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
1717

1818
#include "clang/AST/CanonicalType.h"
19+
#include "clang/CIR/MissingFeatures.h"
1920
#include "llvm/ADT/FoldingSet.h"
2021
#include "llvm/Support/TrailingObjects.h"
2122

clang/lib/CIR/CodeGen/CIRGenTypes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ class CIRGenTypes {
6565
/// types will be in this set.
6666
llvm::SmallPtrSet<const clang::Type *, 4> recordsBeingLaidOut;
6767

68-
llvm::SmallPtrSet<const CIRGenFunctionInfo *, 4> functionsBeingProcessed;
6968
/// Heper for convertType.
7069
mlir::Type convertFunctionTypeInternal(clang::QualType ft);
7170

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
464464
// CallOp
465465
//===----------------------------------------------------------------------===//
466466

467+
mlir::OperandRange cir::CallOp::getArgOperands() {
468+
if (isIndirect())
469+
return getArgs().drop_front(1);
470+
return getArgs();
471+
}
472+
473+
mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
474+
mlir::MutableOperandRange args = getArgsMutable();
475+
if (isIndirect())
476+
return args.slice(1, args.size() - 1);
477+
return args;
478+
}
479+
480+
mlir::Value cir::CallOp::getIndirectCall() {
481+
assert(isIndirect());
482+
return getOperand(0);
483+
}
484+
467485
/// Return the operand at index 'i'.
468486
Value cir::CallOp::getArgOperand(unsigned i) {
469-
assert(!cir::MissingFeatures::opCallIndirect());
487+
if (isIndirect())
488+
++i;
470489
return getOperand(i);
471490
}
472491

473492
/// Return the number of operands.
474493
unsigned cir::CallOp::getNumArgOperands() {
475-
assert(!cir::MissingFeatures::opCallIndirect());
494+
if (isIndirect())
495+
return this->getOperation()->getNumOperands() - 1;
476496
return this->getOperation()->getNumOperands();
477497
}
478498

@@ -483,9 +503,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
483503
mlir::FlatSymbolRefAttr calleeAttr;
484504
llvm::ArrayRef<mlir::Type> allResultTypes;
485505

506+
// If we cannot parse a string callee, it means this is an indirect call.
486507
if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
487-
.has_value())
488-
return mlir::failure();
508+
.has_value()) {
509+
OpAsmParser::UnresolvedOperand indirectVal;
510+
// Do not resolve right now, since we need to figure out the type
511+
if (parser.parseOperand(indirectVal).failed())
512+
return failure();
513+
ops.push_back(indirectVal);
514+
}
489515

490516
if (parser.parseLParen())
491517
return mlir::failure();
@@ -517,13 +543,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
517543

518544
static void printCallCommon(mlir::Operation *op,
519545
mlir::FlatSymbolRefAttr calleeSym,
546+
mlir::Value indirectCallee,
520547
mlir::OpAsmPrinter &printer) {
521548
printer << ' ';
522549

523550
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
524551
auto ops = callLikeOp.getArgOperands();
525552

526-
printer.printAttributeWithoutType(calleeSym);
553+
if (calleeSym) {
554+
// Direct calls
555+
printer.printAttributeWithoutType(calleeSym);
556+
} else {
557+
// Indirect calls
558+
assert(indirectCallee);
559+
printer << indirectCallee;
560+
}
527561
printer << "(" << ops << ")";
528562

529563
printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
@@ -539,15 +573,18 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
539573
}
540574

541575
void cir::CallOp::print(mlir::OpAsmPrinter &p) {
542-
printCallCommon(*this, getCalleeAttr(), p);
576+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
577+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
543578
}
544579

545580
static LogicalResult
546581
verifyCallCommInSymbolUses(mlir::Operation *op,
547582
SymbolTableCollection &symbolTable) {
548583
auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
549-
if (!fnAttr)
550-
return mlir::failure();
584+
if (!fnAttr) {
585+
// This is an indirect call, thus we don't have to check the symbol uses.
586+
return mlir::success();
587+
}
551588

552589
auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
553590
if (!fn)

0 commit comments

Comments
 (0)