Skip to content

Commit dbe66e5

Browse files
committed
[CIR] Cleanup support for C functions
1 parent c290f48 commit dbe66e5

File tree

7 files changed

+244
-58
lines changed

7 files changed

+244
-58
lines changed

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,30 @@
1313

1414
#include "CIRGenCall.h"
1515
#include "CIRGenFunction.h"
16+
#include "CIRGenFunctionInfo.h"
1617
#include "clang/CIR/MissingFeatures.h"
1718

1819
using namespace clang;
1920
using namespace clang::CIRGen;
2021

2122
CIRGenFunctionInfo *
2223
CIRGenFunctionInfo::create(CanQualType resultType,
23-
llvm::ArrayRef<CanQualType> argTypes) {
24+
llvm::ArrayRef<CanQualType> argTypes,
25+
RequiredArgs required) {
2426
// The first slot allocated for ArgInfo is for the return value.
2527
void *buffer = operator new(totalSizeToAlloc<ArgInfo>(argTypes.size() + 1));
2628

29+
assert(!cir::MissingFeatures::opCallCIRGenFuncInfoParamInfo());
30+
2731
CIRGenFunctionInfo *fi = new (buffer) CIRGenFunctionInfo();
28-
fi->numArgs = argTypes.size();
2932

30-
assert(!cir::MissingFeatures::opCallCIRGenFuncInfoParamInfo());
33+
fi->required = required;
34+
fi->numArgs = argTypes.size();
3135

3236
ArgInfo *argsBuffer = fi->getArgsBuffer();
3337
(argsBuffer++)->type = resultType;
3438
for (CanQualType ty : argTypes)
3539
(argsBuffer++)->type = ty;
36-
3740
assert(!cir::MissingFeatures::opCallCIRGenFuncInfoExtParamInfo());
3841

3942
return fi;
@@ -45,7 +48,7 @@ namespace {
4548
/// CIRGenFunctionInfo should be passed to actual CIR function.
4649
class ClangToCIRArgMapping {
4750
static constexpr unsigned invalidIndex = ~0U;
48-
unsigned totalNumCIRArgs;
51+
unsigned totalNumCIRArgs = 0;
4952

5053
/// Arguments of CIR function corresponding to single Clang argument.
5154
struct CIRArgs {
@@ -61,14 +64,20 @@ class ClangToCIRArgMapping {
6164

6265
public:
6366
ClangToCIRArgMapping(const ASTContext &astContext,
64-
const CIRGenFunctionInfo &funcInfo)
65-
: totalNumCIRArgs(0), argInfo(funcInfo.arg_size()) {
67+
const CIRGenFunctionInfo &funcInfo,
68+
bool onlyRequiredArgs)
69+
: argInfo(onlyRequiredArgs ? funcInfo.getNumRequiredArgs()
70+
: funcInfo.argInfoSize()) {
6671
unsigned cirArgNo = 0;
6772

6873
assert(!cir::MissingFeatures::opCallABIIndirectArg());
6974

7075
unsigned argNo = 0;
71-
for (const CIRGenFunctionInfoArgInfo &i : funcInfo.arguments()) {
76+
llvm::ArrayRef<CIRGenFunctionInfoArgInfo> argInfos(
77+
funcInfo.argInfoBegin(), onlyRequiredArgs
78+
? funcInfo.getNumRequiredArgs()
79+
: funcInfo.argInfoSize());
80+
for (const CIRGenFunctionInfoArgInfo &i : argInfos) {
7281
// Collect data about CIR arguments corresponding to Clang argument ArgNo.
7382
CIRArgs &cirArgs = argInfo[argNo];
7483

@@ -119,6 +128,63 @@ class ClangToCIRArgMapping {
119128

120129
} // namespace
121130

131+
cir::FuncType CIRGenTypes::getFunctionType(const CIRGenFunctionInfo &fi) {
132+
bool inserted = functionsBeingProcessed.insert(&fi).second;
133+
(void)inserted;
134+
assert(inserted && "Recursively being processed?");
135+
136+
mlir::Type resultType;
137+
const cir::ABIArgInfo &retInfo = fi.getReturnInfo();
138+
139+
switch (retInfo.getKind()) {
140+
case cir::ABIArgInfo::Ignore:
141+
// TODO(CIR): This should probably be the None type from the builtin
142+
// dialect.
143+
resultType = nullptr;
144+
break;
145+
case cir::ABIArgInfo::Direct:
146+
resultType = retInfo.getCoerceToType();
147+
break;
148+
}
149+
150+
ClangToCIRArgMapping cirFunctionArgs(getASTContext(), fi, true);
151+
SmallVector<mlir::Type, 8> argTypes(cirFunctionArgs.totalCIRArgs());
152+
153+
unsigned argNo = 0;
154+
llvm::ArrayRef<CIRGenFunctionInfoArgInfo> argInfos(fi.argInfoBegin(),
155+
fi.getNumRequiredArgs());
156+
for (const auto &argInfo : argInfos) {
157+
const auto &abiArgInfo = argInfo.info;
158+
159+
unsigned firstCIRArg, numCIRArgs;
160+
std::tie(firstCIRArg, numCIRArgs) = cirFunctionArgs.getCIRArgs(argNo);
161+
162+
switch (abiArgInfo.getKind()) {
163+
case cir::ABIArgInfo::Direct: {
164+
mlir::Type argType = abiArgInfo.getCoerceToType();
165+
// TODO: handle the test against llvm::RecordType from codegen
166+
assert(numCIRArgs == 1);
167+
argTypes[firstCIRArg] = argType;
168+
break;
169+
}
170+
default:
171+
cgm.errorNYI("getFunctionType: unhandled argument kind");
172+
}
173+
174+
++argNo;
175+
}
176+
assert(argNo == fi.argInfoSize() &&
177+
"Mismatch between function info and args");
178+
179+
bool erased = functionsBeingProcessed.erase(&fi);
180+
(void)erased;
181+
assert(erased && "Not in set?");
182+
183+
return cir::FuncType::get(argTypes,
184+
(resultType ? resultType : builder.getVoidTy()),
185+
fi.isVariadic());
186+
}
187+
122188
CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const {
123189
assert(!cir::MissingFeatures::opCallVirtual());
124190
return *this;
@@ -128,6 +194,9 @@ static const CIRGenFunctionInfo &
128194
arrangeFreeFunctionLikeCall(CIRGenTypes &cgt, CIRGenModule &cgm,
129195
const CallArgList &args,
130196
const FunctionType *fnType) {
197+
198+
RequiredArgs required = RequiredArgs::All;
199+
131200
if (const auto *proto = dyn_cast<FunctionProtoType>(fnType)) {
132201
if (proto->isVariadic())
133202
cgm.errorNYI("call to variadic function");
@@ -144,7 +213,7 @@ arrangeFreeFunctionLikeCall(CIRGenTypes &cgt, CIRGenModule &cgm,
144213
CanQualType retType = fnType->getReturnType()
145214
->getCanonicalTypeUnqualified()
146215
.getUnqualifiedType();
147-
return cgt.arrangeCIRFunctionInfo(retType, argTypes);
216+
return cgt.arrangeCIRFunctionInfo(retType, argTypes, required);
148217
}
149218

150219
const CIRGenFunctionInfo &
@@ -168,6 +237,23 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
168237
return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
169238
}
170239

240+
const CIRGenFunctionInfo &
241+
CIRGenTypes::arrangeFreeFunctionType(CanQual<FunctionProtoType> fpt) {
242+
SmallVector<CanQualType, 8> argTypes;
243+
for (unsigned i = 0, e = fpt->getNumParams(); i != e; ++i)
244+
argTypes.push_back(fpt->getParamType(i));
245+
RequiredArgs required = RequiredArgs::forPrototypePlus(fpt);
246+
247+
CanQualType resultType = fpt->getReturnType().getUnqualifiedType();
248+
return arrangeCIRFunctionInfo(resultType, argTypes, required);
249+
}
250+
251+
const CIRGenFunctionInfo &
252+
CIRGenTypes::arrangeFreeFunctionType(CanQual<FunctionNoProtoType> fnpt) {
253+
CanQualType resultType = fnpt->getReturnType().getUnqualifiedType();
254+
return arrangeCIRFunctionInfo(resultType, {}, RequiredArgs(0));
255+
}
256+
171257
RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
172258
const CIRGenCallee &callee,
173259
ReturnValueSlot returnValue,
@@ -177,16 +263,16 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
177263
QualType retTy = funcInfo.getReturnType();
178264
const cir::ABIArgInfo &retInfo = funcInfo.getReturnInfo();
179265

180-
ClangToCIRArgMapping cirFuncArgs(cgm.getASTContext(), funcInfo);
266+
ClangToCIRArgMapping cirFuncArgs(cgm.getASTContext(), funcInfo, false);
181267
SmallVector<mlir::Value, 16> cirCallArgs(cirFuncArgs.totalCIRArgs());
182268

183269
assert(!cir::MissingFeatures::emitLifetimeMarkers());
184270

185271
// Translate all of the arguments as necessary to match the CIR lowering.
186-
assert(funcInfo.arg_size() == args.size() &&
272+
assert(funcInfo.argInfoSize() == args.size() &&
187273
"Mismatch between function signature & arguments.");
188274
unsigned argNo = 0;
189-
for (const auto &[arg, argInfo] : llvm::zip(args, funcInfo.arguments())) {
275+
for (const auto &[arg, argInfo] : llvm::zip(args, funcInfo.argInfos())) {
190276
// Insert a padding argument to ensure proper alignment.
191277
assert(!cir::MissingFeatures::opCallPaddingArgs());
192278

clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,68 @@ struct CIRGenFunctionInfoArgInfo {
2727
cir::ABIArgInfo info;
2828
};
2929

30+
/// A class for recording the number of arguments that a function signature
31+
/// requires.
32+
class RequiredArgs {
33+
/// The number of required arguments, or ~0 if the signature does not permit
34+
/// optional arguments.
35+
unsigned numRequired;
36+
37+
public:
38+
enum All_t { All };
39+
40+
RequiredArgs(All_t _) : numRequired(~0U) {}
41+
explicit RequiredArgs(unsigned n) : numRequired(n) { assert(n != ~0U); }
42+
43+
unsigned getOpaqueData() const { return numRequired; }
44+
45+
bool allowsOptionalArgs() const { return numRequired != ~0U; }
46+
47+
/// Compute the arguments required by the given formal prototype, given that
48+
/// there may be some additional, non-formal arguments in play.
49+
///
50+
/// If FD is not null, this will consider pass_object_size params in FD.
51+
static RequiredArgs
52+
forPrototypePlus(const clang::FunctionProtoType *prototype) {
53+
if (!prototype->isVariadic())
54+
return All;
55+
56+
if (prototype->hasExtParameterInfos())
57+
llvm_unreachable("NYI");
58+
59+
return RequiredArgs(prototype->getNumParams());
60+
}
61+
62+
static RequiredArgs
63+
forPrototypePlus(clang::CanQual<clang::FunctionProtoType> prototype) {
64+
return forPrototypePlus(prototype.getTypePtr());
65+
}
66+
67+
unsigned getNumRequiredArgs() const {
68+
assert(allowsOptionalArgs());
69+
return numRequired;
70+
}
71+
};
72+
3073
class CIRGenFunctionInfo final
3174
: public llvm::FoldingSetNode,
3275
private llvm::TrailingObjects<CIRGenFunctionInfo,
3376
CIRGenFunctionInfoArgInfo> {
3477
using ArgInfo = CIRGenFunctionInfoArgInfo;
3578

79+
RequiredArgs required;
80+
3681
unsigned numArgs;
3782

3883
ArgInfo *getArgsBuffer() { return getTrailingObjects<ArgInfo>(); }
3984
const ArgInfo *getArgsBuffer() const { return getTrailingObjects<ArgInfo>(); }
4085

86+
CIRGenFunctionInfo() : required(RequiredArgs::All) {}
87+
4188
public:
4289
static CIRGenFunctionInfo *create(CanQualType resultType,
43-
llvm::ArrayRef<CanQualType> argTypes);
90+
llvm::ArrayRef<CanQualType> argTypes,
91+
RequiredArgs required);
4492

4593
void operator delete(void *p) { ::operator delete(p); }
4694

@@ -53,35 +101,52 @@ class CIRGenFunctionInfo final
53101

54102
// This function has to be CamelCase because llvm::FoldingSet requires so.
55103
// NOLINTNEXTLINE(readability-identifier-naming)
56-
static void Profile(llvm::FoldingSetNodeID &id, CanQualType resultType,
57-
llvm::ArrayRef<clang::CanQualType> argTypes) {
104+
static void Profile(llvm::FoldingSetNodeID &id, RequiredArgs required,
105+
CanQualType resultType,
106+
llvm::ArrayRef<CanQualType> argTypes) {
107+
id.AddBoolean(required.getOpaqueData());
58108
resultType.Profile(id);
59-
for (auto i : argTypes)
60-
i.Profile(id);
109+
for (const CanQualType &arg : argTypes)
110+
arg.Profile(id);
61111
}
62112

63-
void Profile(llvm::FoldingSetNodeID &id) { getReturnType().Profile(id); }
64-
65-
llvm::MutableArrayRef<ArgInfo> arguments() {
66-
return llvm::MutableArrayRef<ArgInfo>(arg_begin(), numArgs);
67-
}
68-
llvm::ArrayRef<ArgInfo> arguments() const {
69-
return llvm::ArrayRef<ArgInfo>(arg_begin(), numArgs);
113+
// NOLINTNEXTLINE(readability-identifier-naming)
114+
void Profile(llvm::FoldingSetNodeID &id) {
115+
id.AddBoolean(required.getOpaqueData());
116+
getReturnType().Profile(id);
117+
for (const ArgInfo &argInfo : argInfos())
118+
argInfo.type.Profile(id);
70119
}
71120

72-
const_arg_iterator arg_begin() const { return getArgsBuffer() + 1; }
73-
const_arg_iterator arg_end() const { return getArgsBuffer() + 1 + numArgs; }
74-
arg_iterator arg_begin() { return getArgsBuffer() + 1; }
75-
arg_iterator arg_end() { return getArgsBuffer() + 1 + numArgs; }
76-
77-
unsigned arg_size() const { return numArgs; }
78-
79121
CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
80122

81123
cir::ABIArgInfo &getReturnInfo() { return getArgsBuffer()[0].info; }
82124
const cir::ABIArgInfo &getReturnInfo() const {
83125
return getArgsBuffer()[0].info;
84126
}
127+
128+
const_arg_iterator argInfoBegin() const { return getArgsBuffer() + 1; }
129+
const_arg_iterator argInfoEnd() const {
130+
return getArgsBuffer() + 1 + numArgs;
131+
}
132+
arg_iterator argInfoBegin() { return getArgsBuffer() + 1; }
133+
arg_iterator argInfoEnd() { return getArgsBuffer() + 1 + numArgs; }
134+
135+
unsigned argInfoSize() const { return numArgs; }
136+
137+
llvm::MutableArrayRef<ArgInfo> argInfos() {
138+
return llvm::MutableArrayRef<ArgInfo>(argInfoBegin(), numArgs);
139+
}
140+
llvm::ArrayRef<ArgInfo> argInfos() const {
141+
return llvm::ArrayRef<ArgInfo>(argInfoBegin(), numArgs);
142+
}
143+
144+
bool isVariadic() const { return required.allowsOptionalArgs(); }
145+
RequiredArgs getRequiredArgs() const { return required; }
146+
unsigned getNumRequiredArgs() const {
147+
return isVariadic() ? getRequiredArgs().getNumRequiredArgs()
148+
: argInfoSize();
149+
}
85150
};
86151

87152
} // namespace clang::CIRGen

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2424
#include "clang/CIR/MissingFeatures.h"
2525

26+
#include "CIRGenFunctionInfo.h"
2627
#include "mlir/IR/BuiltinOps.h"
2728
#include "mlir/IR/Location.h"
2829
#include "mlir/IR/MLIRContext.h"
@@ -247,8 +248,21 @@ void CIRGenModule::emitGlobalFunctionDefinition(clang::GlobalDecl gd,
247248
"function definition with a non-identifier for a name");
248249
return;
249250
}
250-
cir::FuncType funcType =
251-
cast<cir::FuncType>(convertType(funcDecl->getType()));
251+
252+
cir::FuncType funcType;
253+
// TODO: Move this to arrangeFunctionDeclaration when it is
254+
// implemented.
255+
// When declaring a function without a prototype, always use a
256+
// non-variadic type.
257+
if (CanQual<FunctionNoProtoType> noProto =
258+
funcDecl->getType()
259+
->getCanonicalTypeUnqualified()
260+
.getAs<FunctionNoProtoType>()) {
261+
const CIRGenFunctionInfo &fi = getTypes().arrangeCIRFunctionInfo(
262+
noProto->getReturnType(), {}, RequiredArgs::All);
263+
funcType = getTypes().getFunctionType(fi);
264+
} else
265+
funcType = cast<cir::FuncType>(convertType(funcDecl->getType()));
252266

253267
cir::FuncOp funcOp = dyn_cast_if_present<cir::FuncOp>(op);
254268
if (!funcOp || funcOp.getFunctionType() != funcType) {

0 commit comments

Comments
 (0)