Skip to content

Commit 99e1308

Browse files
authored
[mlir][LLVM] handle argument and result attributes in llvm.call and llvm.invoke (#123177)
Update llvm.call/llvm.invoke pretty printer/parser and the llvm ir import/export to deal with the argument and result attributes. This patch is made on top of PR 123176 that modified the CallOpInterface and added the argument and result attributes to llvm.call and llvm.invoke without doing anything with them. RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107
1 parent d78b5ce commit 99e1308

File tree

13 files changed

+345
-46
lines changed

13 files changed

+345
-46
lines changed

llvm/include/llvm/IR/InstrTypes.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
14901490
Attrs = Attrs.addRetAttribute(getContext(), Attr);
14911491
}
14921492

1493+
/// Adds attributes to the return value.
1494+
void addRetAttrs(const AttrBuilder &B) {
1495+
Attrs = Attrs.addRetAttributes(getContext(), B);
1496+
}
1497+
14931498
/// Adds the attribute to the indicated argument
14941499
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
14951500
assert(ArgNo < arg_size() && "Out of bounds");
@@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
15021507
Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
15031508
}
15041509

1510+
/// Adds attributes to the indicated argument
1511+
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
1512+
assert(ArgNo < arg_size() && "Out of bounds");
1513+
Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
1514+
}
1515+
15051516
/// removes the attribute from the list of attributes.
15061517
void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
15071518
Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,18 @@ class ModuleImport {
341341
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
342342
/// Returns the callee name, or an empty symbol if the call is not direct.
343343
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
344-
/// Converts the parameter attributes attached to `func` and adds them to
345-
/// the `funcOp`.
344+
/// Converts the parameter and result attributes attached to `func` and adds
345+
/// them to the `funcOp`.
346346
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
347347
OpBuilder &builder);
348348
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
349349
/// DictionaryAttr for the LLVM dialect.
350350
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
351351
OpBuilder &builder);
352+
/// Converts the parameter and result attributes attached to `call` and adds
353+
/// them to the `callOp`.
354+
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
355+
OpBuilder &builder);
352356
/// Converts the attributes attached to `inst` and adds them to the `op`.
353357
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
354358
/// Converts the attributes attached to `inst` and adds them to the `op`.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ class ModuleTranslation {
235235
/*recordInsertions=*/false);
236236
}
237237

238+
/// Translates parameter attributes of a call and adds them to the returned
239+
/// AttrBuilder. Returns failure if any of the translations failed.
240+
FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOpInterface callOp,
241+
DictionaryAttr paramAttrs);
242+
238243
/// Gets the named metadata in the LLVM IR module being constructed, creating
239244
/// it if it does not exist.
240245
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
@@ -359,8 +364,8 @@ class ModuleTranslation {
359364
convertDialectAttributes(Operation *op,
360365
ArrayRef<llvm::Instruction *> instructions);
361366

362-
/// Translates parameter attributes and adds them to the returned AttrBuilder.
363-
/// Returns failure if any of the translations failed.
367+
/// Translates parameter attributes of a function and adds them to the
368+
/// returned AttrBuilder. Returns failure if any of the translations failed.
364369
FailureOr<llvm::AttrBuilder>
365370
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
366371

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,55 +1335,66 @@ void CallOp::print(OpAsmPrinter &p) {
13351335
getVarCalleeTypeAttrName(), getCConvAttrName(),
13361336
getOperandSegmentSizesAttrName(),
13371337
getOpBundleSizesAttrName(),
1338-
getOpBundleTagsAttrName()});
1338+
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1339+
getResAttrsAttrName()});
13391340

13401341
p << " : ";
13411342
if (!isDirect)
13421343
p << getOperand(0).getType() << ", ";
13431344

1344-
// Reconstruct the function MLIR function type from operand and result types.
1345-
p.printFunctionalType(args.getTypes(), getResultTypes());
1345+
// Reconstruct the MLIR function type from operand and result types.
1346+
call_interface_impl::printFunctionSignature(
1347+
p, args.getTypes(), getArgAttrsAttr(),
1348+
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
13461349
}
13471350

13481351
/// Parses the type of a call operation and resolves the operands if the parsing
13491352
/// succeeds. Returns failure otherwise.
13501353
static ParseResult parseCallTypeAndResolveOperands(
13511354
OpAsmParser &parser, OperationState &result, bool isDirect,
1352-
ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1355+
ArrayRef<OpAsmParser::UnresolvedOperand> operands,
1356+
SmallVectorImpl<DictionaryAttr> &argAttrs,
1357+
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
13531358
SMLoc trailingTypesLoc = parser.getCurrentLocation();
13541359
SmallVector<Type> types;
1355-
if (parser.parseColonTypeList(types))
1360+
if (parser.parseColon())
13561361
return failure();
1357-
1358-
if (isDirect && types.size() != 1)
1359-
return parser.emitError(trailingTypesLoc,
1360-
"expected direct call to have 1 trailing type");
1361-
if (!isDirect && types.size() != 2)
1362-
return parser.emitError(trailingTypesLoc,
1363-
"expected indirect call to have 2 trailing types");
1364-
1365-
auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
1366-
if (!funcType)
1362+
if (!isDirect) {
1363+
types.emplace_back();
1364+
if (parser.parseType(types.back()))
1365+
return failure();
1366+
if (parser.parseOptionalComma())
1367+
return parser.emitError(
1368+
trailingTypesLoc, "expected indirect call to have 2 trailing types");
1369+
}
1370+
SmallVector<Type> argTypes;
1371+
SmallVector<Type> resTypes;
1372+
if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1373+
resTypes, resultAttrs)) {
1374+
if (isDirect)
1375+
return parser.emitError(trailingTypesLoc,
1376+
"expected direct call to have 1 trailing types");
13671377
return parser.emitError(trailingTypesLoc,
13681378
"expected trailing function type");
1369-
if (funcType.getNumResults() > 1)
1379+
}
1380+
1381+
if (resTypes.size() > 1)
13701382
return parser.emitError(trailingTypesLoc,
13711383
"expected function with 0 or 1 result");
1372-
if (funcType.getNumResults() == 1 &&
1373-
llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
1384+
if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
13741385
return parser.emitError(trailingTypesLoc,
13751386
"expected a non-void result type");
13761387

13771388
// The head element of the types list matches the callee type for
13781389
// indirect calls, while the types list is emtpy for direct calls.
13791390
// Append the function input types to resolve the call operation
13801391
// operands.
1381-
llvm::append_range(types, funcType.getInputs());
1392+
llvm::append_range(types, argTypes);
13821393
if (parser.resolveOperands(operands, types, parser.getNameLoc(),
13831394
result.operands))
13841395
return failure();
1385-
if (funcType.getNumResults() != 0)
1386-
result.addTypes(funcType.getResults());
1396+
if (resTypes.size() != 0)
1397+
result.addTypes(resTypes);
13871398

13881399
return success();
13891400
}
@@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
14971508
return failure();
14981509

14991510
// Parse the trailing type list and resolve the operands.
1500-
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1511+
SmallVector<DictionaryAttr> argAttrs;
1512+
SmallVector<DictionaryAttr> resultAttrs;
1513+
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1514+
argAttrs, resultAttrs))
15011515
return failure();
1516+
call_interface_impl::addArgAndResultAttrs(
1517+
parser.getBuilder(), result, argAttrs, resultAttrs,
1518+
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
15021519
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
15031520
opBundleOperandTypes,
15041521
getOpBundleSizesAttrName(result.name)))
@@ -1643,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
16431660
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
16441661
getCConvAttrName(), getVarCalleeTypeAttrName(),
16451662
getOpBundleSizesAttrName(),
1646-
getOpBundleTagsAttrName()});
1663+
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1664+
getResAttrsAttrName()});
16471665

16481666
p << " : ";
16491667
if (!isDirect)
16501668
p << getOperand(0).getType() << ", ";
1651-
p.printFunctionalType(
1652-
llvm::drop_begin(getCalleeOperands().getTypes(), isDirect ? 0 : 1),
1653-
getResultTypes());
1669+
call_interface_impl::printFunctionSignature(
1670+
p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
1671+
getArgAttrsAttr(),
1672+
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
16541673
}
16551674

16561675
// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
@@ -1659,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
16591678
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
16601679
// ( `vararg(` var-callee-type `)` )?
16611680
// ( `[` op-bundles-list `]` )?
1662-
// attribute-dict? `:` (type `,`)? function-type
1681+
// attribute-dict? `:` (type `,`)?
1682+
// function-type-with-argument-attributes
16631683
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
16641684
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
16651685
SymbolRefAttr funcAttr;
@@ -1721,8 +1741,15 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
17211741
return failure();
17221742

17231743
// Parse the trailing type list and resolve the function operands.
1724-
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1744+
SmallVector<DictionaryAttr> argAttrs;
1745+
SmallVector<DictionaryAttr> resultAttrs;
1746+
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1747+
argAttrs, resultAttrs))
17251748
return failure();
1749+
call_interface_impl::addArgAndResultAttrs(
1750+
parser.getBuilder(), result, argAttrs, resultAttrs,
1751+
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1752+
17261753
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
17271754
opBundleOperandTypes,
17281755
getOpBundleSizesAttrName(result.name)))

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,39 @@ static void convertLinkerOptionsOp(ArrayAttr options,
224224
linkerMDNode->addOperand(listMDNode);
225225
}
226226

227+
static LogicalResult
228+
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
229+
LLVM::ModuleTranslation &moduleTranslation) {
230+
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
231+
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
232+
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
233+
!argAttrs.empty()) {
234+
FailureOr<llvm::AttrBuilder> attrBuilder =
235+
moduleTranslation.convertParameterAttrs(callOp, argAttrs);
236+
if (failed(attrBuilder))
237+
return failure();
238+
call->addParamAttrs(argIdx, *attrBuilder);
239+
}
240+
}
241+
}
242+
243+
ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
244+
if (resAttrsArray && resAttrsArray.size() > 0) {
245+
if (resAttrsArray.size() != 1)
246+
return mlir::emitError(callOp.getLoc(),
247+
"llvm.func cannot have multiple results");
248+
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
249+
!resAttrs.empty()) {
250+
FailureOr<llvm::AttrBuilder> attrBuilder =
251+
moduleTranslation.convertParameterAttrs(callOp, resAttrs);
252+
if (failed(attrBuilder))
253+
return failure();
254+
call->addRetAttrs(*attrBuilder);
255+
}
256+
}
257+
return success();
258+
}
259+
227260
static LogicalResult
228261
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
229262
LLVM::ModuleTranslation &moduleTranslation) {
@@ -265,6 +298,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
265298
if (callOp.getWillReturnAttr())
266299
call->addFnAttr(llvm::Attribute::WillReturn);
267300

301+
if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
302+
return failure();
303+
268304
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
269305
llvm::MemoryEffects memEffects =
270306
llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
@@ -372,6 +408,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
372408
operandsRef.drop_front(), opBundles);
373409
}
374410
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
411+
if (failed(
412+
convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
413+
return failure();
375414
moduleTranslation.mapBranch(invOp, result);
376415
// InvokeOp can only have 0 or 1 result
377416
if (invOp->getNumResults() != 0) {

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17561756
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
17571757
if (failed(convertCallAttributes(callInst, callOp)))
17581758
return failure();
1759+
// Handle parameter and result attributes.
1760+
convertParameterAttributes(callInst, callOp, builder);
17591761
return callOp.getOperation();
17601762
}();
17611763

@@ -1836,6 +1838,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
18361838
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
18371839
return failure();
18381840

1841+
// Handle parameter and result attributes.
1842+
convertParameterAttributes(invokeInst, invokeOp, builder);
1843+
18391844
if (!invokeInst->getType()->isVoidTy())
18401845
mapValue(inst, invokeOp.getResults().front());
18411846
else
@@ -2199,6 +2204,37 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
21992204
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
22002205
}
22012206

2207+
void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
2208+
CallOpInterface callOp,
2209+
OpBuilder &builder) {
2210+
llvm::AttributeList llvmAttrs = call->getAttributes();
2211+
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
2212+
bool anyArgAttrs = false;
2213+
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
2214+
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
2215+
if (llvmArgAttrsSet.back().hasAttributes())
2216+
anyArgAttrs = true;
2217+
}
2218+
auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
2219+
SmallVector<Attribute> attrs;
2220+
for (auto &dict : dictAttrs)
2221+
attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
2222+
return builder.getArrayAttr(attrs);
2223+
};
2224+
if (anyArgAttrs) {
2225+
SmallVector<DictionaryAttr> argAttrs;
2226+
for (auto &llvmArgAttrs : llvmArgAttrsSet)
2227+
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
2228+
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
2229+
}
2230+
2231+
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
2232+
if (!llvmResAttr.hasAttributes())
2233+
return;
2234+
DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
2235+
callOp.setResAttrsAttr(getArrayAttr({resAttrs}));
2236+
}
2237+
22022238
template <typename Op>
22032239
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
22042240
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));

0 commit comments

Comments
 (0)