Skip to content

Commit d67d9bc

Browse files
Lancernsivan-shani
authored andcommitted
[CIR] Add support for indirect calls (llvm#139748)
This PR adds support for indirect calls to the `cir.call` operation.
1 parent 24c923c commit d67d9bc

File tree

12 files changed

+212
-40
lines changed

12 files changed

+212
-40
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
225225
callee.getFunctionType().getReturnType(), operands);
226226
}
227227

228+
cir::CallOp createIndirectCallOp(mlir::Location loc,
229+
mlir::Value indirectTarget,
230+
cir::FuncType funcType,
231+
mlir::ValueRange operands) {
232+
return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
233+
operands);
234+
}
235+
228236
//===--------------------------------------------------------------------===//
229237
// Cast/Conversion Operators
230238
//===--------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,13 +1843,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18431843
DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
18441844
let extraClassDeclaration = [{
18451845
/// Get the argument operands to the called function.
1846-
mlir::OperandRange getArgOperands() {
1847-
return getArgs();
1848-
}
1849-
1850-
mlir::MutableOperandRange getArgOperandsMutable() {
1851-
return getArgsMutable();
1852-
}
1846+
mlir::OperandRange getArgOperands();
1847+
mlir::MutableOperandRange getArgOperandsMutable();
18531848

18541849
/// Return the callee of this operation
18551850
mlir::CallInterfaceCallable getCallableForCallee() {
@@ -1871,8 +1866,17 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18711866
::mlir::Attribute removeArgAttrsAttr() { return {}; }
18721867
::mlir::Attribute removeResAttrsAttr() { return {}; }
18731868

1869+
bool isIndirect() { return !getCallee(); }
1870+
mlir::Value getIndirectCall();
1871+
18741872
void setArg(unsigned index, mlir::Value value) {
1875-
setOperand(index, value);
1873+
if (!isIndirect()) {
1874+
setOperand(index, value);
1875+
return;
1876+
}
1877+
1878+
// For indirect call, the operand list is shifted by one.
1879+
setOperand(index + 1, value);
18761880
}
18771881
}];
18781882

@@ -1884,16 +1888,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18841888
// the upstreaming process moves on. The verifiers is also missing for now,
18851889
// will add in the future.
18861890

1887-
dag commonArgs = (ins FlatSymbolRefAttr:$callee,
1888-
Variadic<CIR_AnyType>:$args);
1891+
dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
1892+
Variadic<CIR_AnyType>:$args);
18891893
}
18901894

18911895
def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18921896
let summary = "call a function";
18931897
let description = [{
1894-
The `cir.call` operation represents a direct call to a function that is
1895-
within the same symbol scope as the call. The callee is encoded as a symbol
1896-
reference attribute named `callee`.
1898+
The `cir.call` operation represents a function call. It could represent
1899+
either a direct call or an indirect call.
1900+
1901+
If the operation represents a direct call, the callee should be defined
1902+
within the same symbol scope as the call. The `callee` attribute contains a
1903+
symbol reference to the callee function. All operands of this operation are
1904+
arguments to the callee function.
1905+
1906+
If the operation represents an indirect call, the `callee` attribute is
1907+
empty. The first operand of this operation must be a pointer to the callee
1908+
function. The rest operands are arguments to the callee function.
18971909

18981910
Example:
18991911

@@ -1905,14 +1917,25 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
19051917
let results = (outs Optional<CIR_AnyType>:$result);
19061918
let arguments = commonArgs;
19071919

1908-
let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
1909-
"mlir::Type":$resType,
1910-
"mlir::ValueRange":$operands), [{
1920+
let builders = [
1921+
// Build a call op for a direct call
1922+
OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
1923+
"mlir::ValueRange":$operands), [{
1924+
assert(callee && "callee attribute is required for direct call");
19111925
$_state.addOperands(operands);
19121926
$_state.addAttribute("callee", callee);
19131927
if (resType && !isa<VoidType>(resType))
19141928
$_state.addTypes(resType);
1915-
}]>];
1929+
}]>,
1930+
// Build a call op for an indirect call
1931+
OpBuilder<(ins "mlir::Value":$calleePtr, "mlir::Type":$resType,
1932+
"mlir::ValueRange":$operands), [{
1933+
$_state.addOperands(calleePtr);
1934+
$_state.addOperands(operands);
1935+
if (resType && !isa<VoidType>(resType))
1936+
$_state.addTypes(resType);
1937+
}]>,
1938+
];
19161939
}
19171940

19181941
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ struct MissingFeatures {
9292
static bool opCallSideEffect() { return false; }
9393
static bool opCallNoPrototypeFunc() { return false; }
9494
static bool opCallMustTail() { return false; }
95-
static bool opCallIndirect() { return false; }
9695
static bool opCallVirtual() { return false; }
9796
static bool opCallInAlloca() { return false; }
9897
static bool opCallAttrs() { return false; }

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ CIRGenTypes::arrangeFunctionDeclaration(const FunctionDecl *fd) {
241241

242242
static cir::CIRCallOpInterface
243243
emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
244+
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
244245
cir::FuncOp directFuncOp,
245246
const SmallVectorImpl<mlir::Value> &cirCallArgs) {
246247
CIRGenBuilderTy &builder = cgf.getBuilder();
@@ -249,7 +250,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
249250
assert(!cir::MissingFeatures::invokeOp());
250251

251252
assert(builder.getInsertionBlock() && "expected valid basic block");
252-
assert(!cir::MissingFeatures::opCallIndirect());
253+
254+
if (indirectFuncTy) {
255+
// TODO(cir): Set calling convention for indirect calls.
256+
assert(!cir::MissingFeatures::opCallCallConv());
257+
return builder.createIndirectCallOp(callLoc, indirectFuncVal,
258+
indirectFuncTy, cirCallArgs);
259+
}
253260

254261
return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
255262
}
@@ -275,6 +282,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
275282
cir::CIRCallOpInterface *callOp,
276283
mlir::Location loc) {
277284
QualType retTy = funcInfo.getReturnType();
285+
cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo);
278286

279287
SmallVector<mlir::Value, 16> cirCallArgs(args.size());
280288

@@ -326,12 +334,27 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
326334

327335
assert(!cir::MissingFeatures::invokeOp());
328336

329-
auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr);
330-
assert(!cir::MissingFeatures::opCallIndirect());
337+
cir::FuncType indirectFuncTy;
338+
mlir::Value indirectFuncVal;
339+
cir::FuncOp directFuncOp;
340+
if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr)) {
341+
directFuncOp = fnOp;
342+
} else {
343+
[[maybe_unused]] mlir::ValueTypeRange<mlir::ResultRange> resultTypes =
344+
calleePtr->getResultTypes();
345+
[[maybe_unused]] auto funcPtrTy =
346+
mlir::dyn_cast<cir::PointerType>(resultTypes.front());
347+
assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) &&
348+
"expected pointer to function");
349+
350+
indirectFuncTy = cirFuncTy;
351+
indirectFuncVal = calleePtr->getResult(0);
352+
}
353+
331354
assert(!cir::MissingFeatures::opCallAttrs());
332355

333-
cir::CIRCallOpInterface theCall =
334-
emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs);
356+
cir::CIRCallOpInterface theCall = emitCallLikeOp(
357+
*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);
335358

336359
if (callOp)
337360
*callOp = theCall;
@@ -431,7 +454,7 @@ void CIRGenFunction::emitCallArgs(
431454

432455
auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg,
433456
RValue emittedArg) {
434-
if (callee.hasFunctionDecl() || i >= callee.getNumParams())
457+
if (!callee.hasFunctionDecl() || i >= callee.getNumParams())
435458
return;
436459
auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>();
437460
if (!ps)

clang/lib/CIR/CodeGen/CIRGenCall.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,20 @@ class CIRGenFunction;
2525

2626
/// Abstract information about a function or function prototype.
2727
class CIRGenCalleeInfo {
28+
const clang::FunctionProtoType *calleeProtoTy;
2829
clang::GlobalDecl calleeDecl;
2930

3031
public:
31-
explicit CIRGenCalleeInfo() : calleeDecl() {}
32+
explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {}
33+
CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy,
34+
clang::GlobalDecl calleeDecl)
35+
: calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {}
3236
CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {}
37+
38+
const clang::FunctionProtoType *getCalleeFunctionProtoType() const {
39+
return calleeProtoTy;
40+
}
41+
clang::GlobalDecl getCalleeDecl() const { return calleeDecl; }
3342
};
3443

3544
class CIRGenCallee {

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -931,14 +931,43 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
931931
implicitCast->getCastKind() == CK_BuiltinFnToFnPtr) {
932932
return emitCallee(implicitCast->getSubExpr());
933933
}
934+
// When performing an indirect call through a function pointer lvalue, the
935+
// function pointer lvalue is implicitly converted to an rvalue through an
936+
// lvalue-to-rvalue conversion.
937+
assert(implicitCast->getCastKind() == CK_LValueToRValue &&
938+
"unexpected implicit cast on function pointers");
934939
} else if (const auto *declRef = dyn_cast<DeclRefExpr>(e)) {
935940
// Resolve direct calls.
936-
if (const auto *funcDecl = dyn_cast<FunctionDecl>(declRef->getDecl()))
937-
return emitDirectCallee(cgm, funcDecl);
941+
const auto *funcDecl = cast<FunctionDecl>(declRef->getDecl());
942+
return emitDirectCallee(cgm, funcDecl);
943+
} else if (isa<MemberExpr>(e)) {
944+
cgm.errorNYI(e->getSourceRange(),
945+
"emitCallee: call to member function is NYI");
946+
return {};
938947
}
939948

940-
cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind");
941-
return {};
949+
assert(!cir::MissingFeatures::opCallPseudoDtor());
950+
951+
// Otherwise, we have an indirect reference.
952+
mlir::Value calleePtr;
953+
QualType functionType;
954+
if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) {
955+
calleePtr = emitScalarExpr(e);
956+
functionType = ptrType->getPointeeType();
957+
} else {
958+
functionType = e->getType();
959+
calleePtr = emitLValue(e).getPointer();
960+
}
961+
assert(functionType->isFunctionType());
962+
963+
GlobalDecl gd;
964+
if (const auto *vd =
965+
dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee()))
966+
gd = GlobalDecl(vd);
967+
968+
CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd);
969+
CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp());
970+
return callee;
942971
}
943972

944973
RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,

clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
1717

1818
#include "clang/AST/CanonicalType.h"
19+
#include "clang/CIR/MissingFeatures.h"
1920
#include "llvm/ADT/FoldingSet.h"
2021
#include "llvm/Support/TrailingObjects.h"
2122

clang/lib/CIR/CodeGen/CIRGenTypes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ class CIRGenTypes {
6868
/// types will be in this set.
6969
llvm::SmallPtrSet<const clang::Type *, 4> recordsBeingLaidOut;
7070

71-
llvm::SmallPtrSet<const CIRGenFunctionInfo *, 4> functionsBeingProcessed;
7271
/// Heper for convertType.
7372
mlir::Type convertFunctionTypeInternal(clang::QualType ft);
7473

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -465,15 +465,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
465465
// CallOp
466466
//===----------------------------------------------------------------------===//
467467

468+
mlir::OperandRange cir::CallOp::getArgOperands() {
469+
if (isIndirect())
470+
return getArgs().drop_front(1);
471+
return getArgs();
472+
}
473+
474+
mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
475+
mlir::MutableOperandRange args = getArgsMutable();
476+
if (isIndirect())
477+
return args.slice(1, args.size() - 1);
478+
return args;
479+
}
480+
481+
mlir::Value cir::CallOp::getIndirectCall() {
482+
assert(isIndirect());
483+
return getOperand(0);
484+
}
485+
468486
/// Return the operand at index 'i'.
469487
Value cir::CallOp::getArgOperand(unsigned i) {
470-
assert(!cir::MissingFeatures::opCallIndirect());
488+
if (isIndirect())
489+
++i;
471490
return getOperand(i);
472491
}
473492

474493
/// Return the number of operands.
475494
unsigned cir::CallOp::getNumArgOperands() {
476-
assert(!cir::MissingFeatures::opCallIndirect());
495+
if (isIndirect())
496+
return this->getOperation()->getNumOperands() - 1;
477497
return this->getOperation()->getNumOperands();
478498
}
479499

@@ -484,9 +504,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
484504
mlir::FlatSymbolRefAttr calleeAttr;
485505
llvm::ArrayRef<mlir::Type> allResultTypes;
486506

507+
// If we cannot parse a string callee, it means this is an indirect call.
487508
if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
488-
.has_value())
489-
return mlir::failure();
509+
.has_value()) {
510+
OpAsmParser::UnresolvedOperand indirectVal;
511+
// Do not resolve right now, since we need to figure out the type
512+
if (parser.parseOperand(indirectVal).failed())
513+
return failure();
514+
ops.push_back(indirectVal);
515+
}
490516

491517
if (parser.parseLParen())
492518
return mlir::failure();
@@ -518,13 +544,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
518544

519545
static void printCallCommon(mlir::Operation *op,
520546
mlir::FlatSymbolRefAttr calleeSym,
547+
mlir::Value indirectCallee,
521548
mlir::OpAsmPrinter &printer) {
522549
printer << ' ';
523550

524551
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
525552
auto ops = callLikeOp.getArgOperands();
526553

527-
printer.printAttributeWithoutType(calleeSym);
554+
if (calleeSym) {
555+
// Direct calls
556+
printer.printAttributeWithoutType(calleeSym);
557+
} else {
558+
// Indirect calls
559+
assert(indirectCallee);
560+
printer << indirectCallee;
561+
}
528562
printer << "(" << ops << ")";
529563

530564
printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
@@ -540,15 +574,18 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
540574
}
541575

542576
void cir::CallOp::print(mlir::OpAsmPrinter &p) {
543-
printCallCommon(*this, getCalleeAttr(), p);
577+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
578+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
544579
}
545580

546581
static LogicalResult
547582
verifyCallCommInSymbolUses(mlir::Operation *op,
548583
SymbolTableCollection &symbolTable) {
549584
auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
550-
if (!fnAttr)
551-
return mlir::failure();
585+
if (!fnAttr) {
586+
// This is an indirect call, thus we don't have to check the symbol uses.
587+
return mlir::success();
588+
}
552589

553590
auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
554591
if (!fn)

0 commit comments

Comments
 (0)