Skip to content

[CIR] Add support for indirect calls #139748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
callee.getFunctionType().getReturnType(), operands);
}

cir::CallOp createIndirectCallOp(mlir::Location loc,
mlir::Value indirectTarget,
cir::FuncType funcType,
mlir::ValueRange operands) {
return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
operands);
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
57 changes: 40 additions & 17 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1798,13 +1798,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
let extraClassDeclaration = [{
/// Get the argument operands to the called function.
mlir::OperandRange getArgOperands() {
return getArgs();
}

mlir::MutableOperandRange getArgOperandsMutable() {
return getArgsMutable();
}
mlir::OperandRange getArgOperands();
mlir::MutableOperandRange getArgOperandsMutable();

/// Return the callee of this operation
mlir::CallInterfaceCallable getCallableForCallee() {
Expand All @@ -1826,8 +1821,17 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
::mlir::Attribute removeArgAttrsAttr() { return {}; }
::mlir::Attribute removeResAttrsAttr() { return {}; }

bool isIndirect() { return !getCallee(); }
mlir::Value getIndirectCall();

void setArg(unsigned index, mlir::Value value) {
setOperand(index, value);
if (!isIndirect()) {
setOperand(index, value);
return;
}

// For indirect call, the operand list is shifted by one.
setOperand(index + 1, value);
}
}];

Expand All @@ -1839,16 +1843,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
// the upstreaming process moves on. The verifiers is also missing for now,
// will add in the future.

dag commonArgs = (ins FlatSymbolRefAttr:$callee,
Variadic<CIR_AnyType>:$args);
dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$args);
}

def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
let summary = "call a function";
let description = [{
The `cir.call` operation represents a direct call to a function that is
within the same symbol scope as the call. The callee is encoded as a symbol
reference attribute named `callee`.
The `cir.call` operation represents a function call. It could represent
either a direct call or an indirect call.

If the operation represents a direct call, the callee should be defined
within the same symbol scope as the call. The `callee` attribute contains a
symbol reference to the callee function. All operands of this operation are
arguments to the callee function.

If the operation represents an indirect call, the `callee` attribute is
empty. The first operand of this operation must be a pointer to the callee
function. The rest operands are arguments to the callee function.

Example:

Expand All @@ -1860,14 +1872,25 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
let results = (outs Optional<CIR_AnyType>:$result);
let arguments = commonArgs;

let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
"mlir::Type":$resType,
"mlir::ValueRange":$operands), [{
let builders = [
// Build a call op for a direct call
OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
"mlir::ValueRange":$operands), [{
assert(callee && "callee attribute is required for direct call");
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
if (resType && !isa<VoidType>(resType))
$_state.addTypes(resType);
}]>];
}]>,
// Build a call op for an indirect call
OpBuilder<(ins "mlir::Value":$calleePtr, "mlir::Type":$resType,
"mlir::ValueRange":$operands), [{
$_state.addOperands(calleePtr);
$_state.addOperands(operands);
if (resType && !isa<VoidType>(resType))
$_state.addTypes(resType);
}]>,
];
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ struct MissingFeatures {
static bool opCallChainCall() { return false; }
static bool opCallNoPrototypeFunc() { return false; }
static bool opCallMustTail() { return false; }
static bool opCallIndirect() { return false; }
static bool opCallVirtual() { return false; }
static bool opCallInAlloca() { return false; }
static bool opCallAttrs() { return false; }
Expand Down
35 changes: 29 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args,

static cir::CIRCallOpInterface
emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
cir::FuncOp directFuncOp,
const SmallVectorImpl<mlir::Value> &cirCallArgs) {
CIRGenBuilderTy &builder = cgf.getBuilder();
Expand All @@ -105,7 +106,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
assert(!cir::MissingFeatures::invokeOp());

assert(builder.getInsertionBlock() && "expected valid basic block");
assert(!cir::MissingFeatures::opCallIndirect());

if (indirectFuncTy) {
// TODO(cir): Set calling convention for indirect calls.
assert(!cir::MissingFeatures::opCallCallConv());
return builder.createIndirectCallOp(callLoc, indirectFuncVal,
indirectFuncTy, cirCallArgs);
}

return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
}
Expand Down Expand Up @@ -134,6 +141,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
cir::CIRCallOpInterface *callOp,
mlir::Location loc) {
QualType retTy = funcInfo.getReturnType();
cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo);

SmallVector<mlir::Value, 16> cirCallArgs(args.size());

Expand Down Expand Up @@ -185,12 +193,27 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,

assert(!cir::MissingFeatures::invokeOp());

auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr);
assert(!cir::MissingFeatures::opCallIndirect());
cir::FuncType indirectFuncTy;
mlir::Value indirectFuncVal;
cir::FuncOp directFuncOp;
if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr)) {
directFuncOp = fnOp;
} else {
[[maybe_unused]] mlir::ValueTypeRange<mlir::ResultRange> resultTypes =
calleePtr->getResultTypes();
[[maybe_unused]] auto funcPtrTy =
mlir::dyn_cast<cir::PointerType>(resultTypes.front());
assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) &&
"expected pointer to function");

indirectFuncTy = cirFuncTy;
indirectFuncVal = calleePtr->getResult(0);
}

assert(!cir::MissingFeatures::opCallAttrs());

cir::CIRCallOpInterface theCall =
emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs);
cir::CIRCallOpInterface theCall = emitCallLikeOp(
*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);

if (callOp)
*callOp = theCall;
Expand Down Expand Up @@ -290,7 +313,7 @@ void CIRGenFunction::emitCallArgs(

auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg,
RValue emittedArg) {
if (callee.hasFunctionDecl() || i >= callee.getNumParams())
if (!callee.hasFunctionDecl() || i >= callee.getNumParams())
return;
auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>();
if (!ps)
Expand Down
11 changes: 10 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@ class CIRGenFunction;

/// Abstract information about a function or function prototype.
class CIRGenCalleeInfo {
const clang::FunctionProtoType *calleeProtoTy;
clang::GlobalDecl calleeDecl;

public:
explicit CIRGenCalleeInfo() : calleeDecl() {}
explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {}
CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy,
clang::GlobalDecl calleeDecl)
: calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {}
CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {}

const clang::FunctionProtoType *getCalleeFunctionProtoType() const {
return calleeProtoTy;
}
clang::GlobalDecl getCalleeDecl() const { return calleeDecl; }
};

class CIRGenCallee {
Expand Down
37 changes: 33 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,14 +931,43 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
implicitCast->getCastKind() == CK_BuiltinFnToFnPtr) {
return emitCallee(implicitCast->getSubExpr());
}
// When performing an indirect call through a function pointer lvalue, the
// function pointer lvalue is implicitly converted to an rvalue through an
// lvalue-to-rvalue conversion.
assert(implicitCast->getCastKind() == CK_LValueToRValue &&
"unexpected implicit cast on function pointers");
} else if (const auto *declRef = dyn_cast<DeclRefExpr>(e)) {
// Resolve direct calls.
if (const auto *funcDecl = dyn_cast<FunctionDecl>(declRef->getDecl()))
return emitDirectCallee(cgm, funcDecl);
const auto *funcDecl = cast<FunctionDecl>(declRef->getDecl());
return emitDirectCallee(cgm, funcDecl);
} else if (isa<MemberExpr>(e)) {
cgm.errorNYI(e->getSourceRange(),
"emitCallee: call to member function is NYI");
return {};
}

cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind");
return {};
assert(!cir::MissingFeatures::opCallPseudoDtor());

// Otherwise, we have an indirect reference.
mlir::Value calleePtr;
QualType functionType;
if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) {
calleePtr = emitScalarExpr(e);
functionType = ptrType->getPointeeType();
} else {
functionType = e->getType();
calleePtr = emitLValue(e).getPointer();
}
assert(functionType->isFunctionType());

GlobalDecl gd;
if (const auto *vd =
dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee()))
gd = GlobalDecl(vd);

CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd);
CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp());
return callee;
}

RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H

#include "clang/AST/CanonicalType.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/Support/TrailingObjects.h"

Expand Down
1 change: 0 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class CIRGenTypes {
/// types will be in this set.
llvm::SmallPtrSet<const clang::Type *, 4> recordsBeingLaidOut;

llvm::SmallPtrSet<const CIRGenFunctionInfo *, 4> functionsBeingProcessed;
/// Heper for convertType.
mlir::Type convertFunctionTypeInternal(clang::QualType ft);

Expand Down
53 changes: 45 additions & 8 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
// CallOp
//===----------------------------------------------------------------------===//

mlir::OperandRange cir::CallOp::getArgOperands() {
if (isIndirect())
return getArgs().drop_front(1);
return getArgs();
}

mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
mlir::MutableOperandRange args = getArgsMutable();
if (isIndirect())
return args.slice(1, args.size() - 1);
return args;
}

mlir::Value cir::CallOp::getIndirectCall() {
assert(isIndirect());
return getOperand(0);
}

/// Return the operand at index 'i'.
Value cir::CallOp::getArgOperand(unsigned i) {
assert(!cir::MissingFeatures::opCallIndirect());
if (isIndirect())
++i;
return getOperand(i);
}

/// Return the number of operands.
unsigned cir::CallOp::getNumArgOperands() {
assert(!cir::MissingFeatures::opCallIndirect());
if (isIndirect())
return this->getOperation()->getNumOperands() - 1;
return this->getOperation()->getNumOperands();
}

Expand All @@ -483,9 +503,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
mlir::FlatSymbolRefAttr calleeAttr;
llvm::ArrayRef<mlir::Type> allResultTypes;

// If we cannot parse a string callee, it means this is an indirect call.
if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
.has_value())
return mlir::failure();
.has_value()) {
OpAsmParser::UnresolvedOperand indirectVal;
// Do not resolve right now, since we need to figure out the type
if (parser.parseOperand(indirectVal).failed())
return failure();
ops.push_back(indirectVal);
}

if (parser.parseLParen())
return mlir::failure();
Expand Down Expand Up @@ -517,13 +543,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,

static void printCallCommon(mlir::Operation *op,
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer) {
printer << ' ';

auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
auto ops = callLikeOp.getArgOperands();

printer.printAttributeWithoutType(calleeSym);
if (calleeSym) {
// Direct calls
printer.printAttributeWithoutType(calleeSym);
} else {
// Indirect calls
assert(indirectCallee);
printer << indirectCallee;
}
printer << "(" << ops << ")";

printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
Expand All @@ -539,15 +573,18 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
}

void cir::CallOp::print(mlir::OpAsmPrinter &p) {
printCallCommon(*this, getCalleeAttr(), p);
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
}

static LogicalResult
verifyCallCommInSymbolUses(mlir::Operation *op,
SymbolTableCollection &symbolTable) {
auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return mlir::failure();
if (!fnAttr) {
// This is an indirect call, thus we don't have to check the symbol uses.
return mlir::success();
}

auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
if (!fn)
Expand Down
Loading