Skip to content

Commit f062d78

Browse files
committed
[mlir][LLVM] add argument and result attributes to llvm.call
1 parent 9216419 commit f062d78

File tree

11 files changed

+220
-40
lines changed

11 files changed

+220
-40
lines changed

llvm/include/llvm/IR/InstrTypes.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,11 @@ class CallBase : public Instruction {
14901490
Attrs = Attrs.addRetAttribute(getContext(), Attr);
14911491
}
14921492

1493+
/// Adds attributes to the return value.
1494+
void addRetAttrs(const AttrBuilder &B) {
1495+
Attrs = Attrs.addRetAttributes(getContext(), B);
1496+
}
1497+
14931498
/// Adds the attribute to the indicated argument
14941499
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) {
14951500
assert(ArgNo < arg_size() && "Out of bounds");
@@ -1502,6 +1507,12 @@ class CallBase : public Instruction {
15021507
Attrs = Attrs.addParamAttribute(getContext(), ArgNo, Attr);
15031508
}
15041509

1510+
/// Adds attributes to the indicated argument
1511+
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B) {
1512+
assert(ArgNo < arg_size() && "Out of bounds");
1513+
Attrs = Attrs.addParamAttributes(getContext(), ArgNo, B);
1514+
}
1515+
15051516
/// removes the attribute from the list of attributes.
15061517
void removeAttributeAtIndex(unsigned i, Attribute::AttrKind Kind) {
15071518
Attrs = Attrs.removeAttributeAtIndex(getContext(), i, Kind);

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,18 @@ class ModuleImport {
335335
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
336336
/// Returns the callee name, or an empty symbol if the call is not direct.
337337
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
338-
/// Converts the parameter attributes attached to `func` and adds them to
339-
/// the `funcOp`.
338+
/// Converts the parameter and result attributes attached to `func` and adds
339+
/// them to the `funcOp`.
340340
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
341341
OpBuilder &builder);
342342
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
343343
/// DictionaryAttr for the LLVM dialect.
344344
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
345345
OpBuilder &builder);
346+
/// Converts the parameter and result attributes attached to `call` and adds
347+
/// them to the `callOp`.
348+
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
349+
OpBuilder &builder);
346350
/// Converts the attributes attached to `inst` and adds them to the `op`.
347351
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
348352
/// Converts the attributes attached to `inst` and adds them to the `op`.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ class ModuleTranslation {
228228
/*recordInsertions=*/false);
229229
}
230230

231+
/// Translates parameter attributes of a call and adds them to the returned
232+
/// AttrBuilder. Returns failure if any of the translations failed.
233+
FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOp callOp, int argIdx,
234+
DictionaryAttr paramAttrs);
235+
231236
/// Gets the named metadata in the LLVM IR module being constructed, creating
232237
/// it if it does not exist.
233238
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
@@ -346,8 +351,8 @@ class ModuleTranslation {
346351
convertDialectAttributes(Operation *op,
347352
ArrayRef<llvm::Instruction *> instructions);
348353

349-
/// Translates parameter attributes and adds them to the returned AttrBuilder.
350-
/// Returns failure if any of the translations failed.
354+
/// Translates parameter attributes of a function and adds them to the
355+
/// returned AttrBuilder. Returns failure if any of the translations failed.
351356
FailureOr<llvm::AttrBuilder>
352357
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
353358

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,55 +1335,66 @@ void CallOp::print(OpAsmPrinter &p) {
13351335
getVarCalleeTypeAttrName(), getCConvAttrName(),
13361336
getOperandSegmentSizesAttrName(),
13371337
getOpBundleSizesAttrName(),
1338-
getOpBundleTagsAttrName()});
1338+
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1339+
getResAttrsAttrName()});
13391340

13401341
p << " : ";
13411342
if (!isDirect)
13421343
p << getOperand(0).getType() << ", ";
13431344

13441345
// Reconstruct the function MLIR function type from operand and result types.
1345-
p.printFunctionalType(args.getTypes(), getResultTypes());
1346+
call_interface_impl::printFunctionSignature(
1347+
p, args.getTypes(), getArgAttrsAttr(),
1348+
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
13461349
}
13471350

13481351
/// Parses the type of a call operation and resolves the operands if the parsing
13491352
/// succeeds. Returns failure otherwise.
13501353
static ParseResult parseCallTypeAndResolveOperands(
13511354
OpAsmParser &parser, OperationState &result, bool isDirect,
1352-
ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1355+
ArrayRef<OpAsmParser::UnresolvedOperand> operands,
1356+
SmallVectorImpl<DictionaryAttr> &argAttrs,
1357+
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
13531358
SMLoc trailingTypesLoc = parser.getCurrentLocation();
13541359
SmallVector<Type> types;
1355-
if (parser.parseColonTypeList(types))
1360+
if (parser.parseColon())
13561361
return failure();
1357-
1358-
if (isDirect && types.size() != 1)
1359-
return parser.emitError(trailingTypesLoc,
1360-
"expected direct call to have 1 trailing type");
1361-
if (!isDirect && types.size() != 2)
1362-
return parser.emitError(trailingTypesLoc,
1363-
"expected indirect call to have 2 trailing types");
1364-
1365-
auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
1366-
if (!funcType)
1362+
if (!isDirect) {
1363+
types.emplace_back();
1364+
if (parser.parseType(types.back()))
1365+
return failure();
1366+
if (parser.parseOptionalComma())
1367+
return parser.emitError(
1368+
trailingTypesLoc, "expected indirect call to have 2 trailing types");
1369+
}
1370+
SmallVector<Type> argTypes;
1371+
SmallVector<Type> resTypes;
1372+
if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1373+
resTypes, resultAttrs)) {
1374+
if (isDirect)
1375+
return parser.emitError(trailingTypesLoc,
1376+
"expected direct call to have 1 trailing types");
13671377
return parser.emitError(trailingTypesLoc,
13681378
"expected trailing function type");
1369-
if (funcType.getNumResults() > 1)
1379+
}
1380+
1381+
if (resTypes.size() > 1)
13701382
return parser.emitError(trailingTypesLoc,
13711383
"expected function with 0 or 1 result");
1372-
if (funcType.getNumResults() == 1 &&
1373-
llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
1384+
if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
13741385
return parser.emitError(trailingTypesLoc,
13751386
"expected a non-void result type");
13761387

13771388
// The head element of the types list matches the callee type for
13781389
// indirect calls, while the types list is emtpy for direct calls.
13791390
// Append the function input types to resolve the call operation
13801391
// operands.
1381-
llvm::append_range(types, funcType.getInputs());
1392+
llvm::append_range(types, argTypes);
13821393
if (parser.resolveOperands(operands, types, parser.getNameLoc(),
13831394
result.operands))
13841395
return failure();
1385-
if (funcType.getNumResults() != 0)
1386-
result.addTypes(funcType.getResults());
1396+
if (resTypes.size() != 0)
1397+
result.addTypes(resTypes);
13871398

13881399
return success();
13891400
}
@@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
14971508
return failure();
14981509

14991510
// Parse the trailing type list and resolve the operands.
1500-
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1511+
SmallVector<DictionaryAttr> argAttrs;
1512+
SmallVector<DictionaryAttr> resultAttrs;
1513+
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1514+
argAttrs, resultAttrs))
15011515
return failure();
1516+
call_interface_impl::addArgAndResultAttrs(
1517+
parser.getBuilder(), result, argAttrs, resultAttrs,
1518+
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
15021519
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
15031520
opBundleOperandTypes,
15041521
getOpBundleSizesAttrName(result.name)))
@@ -1721,7 +1738,10 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
17211738
return failure();
17221739

17231740
// Parse the trailing type list and resolve the function operands.
1724-
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1741+
SmallVector<DictionaryAttr> argAttrs;
1742+
SmallVector<DictionaryAttr> resultAttrs;
1743+
if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1744+
argAttrs, resultAttrs))
17251745
return failure();
17261746
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
17271747
opBundleOperandTypes,

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,27 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
265265
if (callOp.getWillReturnAttr())
266266
call->addFnAttr(llvm::Attribute::WillReturn);
267267

268+
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr())
269+
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
270+
if (auto argAttrs = llvm::cast<DictionaryAttr>(argAttrsAttr)) {
271+
FailureOr<llvm::AttrBuilder> attrBuilder =
272+
moduleTranslation.convertParameterAttrs(callOp, argIdx, argAttrs);
273+
if (failed(attrBuilder))
274+
return failure();
275+
call->addParamAttrs(argIdx, *attrBuilder);
276+
}
277+
}
278+
279+
ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
280+
if (resAttrsArray && resAttrsArray.size() == 1)
281+
if (auto resAttrs = llvm::cast<DictionaryAttr>(resAttrsArray[0])) {
282+
FailureOr<llvm::AttrBuilder> attrBuilder =
283+
moduleTranslation.convertParameterAttrs(callOp, -1, resAttrs);
284+
if (failed(attrBuilder))
285+
return failure();
286+
call->addRetAttrs(*attrBuilder);
287+
}
288+
268289
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
269290
llvm::MemoryEffects memEffects =
270291
llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,6 +1706,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17061706
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
17071707
if (failed(convertCallAttributes(callInst, callOp)))
17081708
return failure();
1709+
// Handle parameter and result attributes.
1710+
convertParameterAttributes(callInst, callOp, builder);
17091711
return callOp.getOperation();
17101712
}();
17111713

@@ -2149,6 +2151,38 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
21492151
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
21502152
}
21512153

2154+
void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
2155+
CallOpInterface callOp,
2156+
OpBuilder &builder) {
2157+
auto llvmAttrs = call->getAttributes();
2158+
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
2159+
bool anyArgAttrs = false;
2160+
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
2161+
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
2162+
if (llvmArgAttrsSet.back().hasAttributes())
2163+
anyArgAttrs = true;
2164+
}
2165+
auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
2166+
SmallVector<Attribute> attrs;
2167+
for (auto &dict : dictAttrs)
2168+
attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
2169+
return builder.getArrayAttr(attrs);
2170+
};
2171+
if (anyArgAttrs) {
2172+
SmallVector<DictionaryAttr> argAttrs;
2173+
for (auto &llvmArgAttrs : llvmArgAttrsSet)
2174+
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
2175+
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
2176+
}
2177+
2178+
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
2179+
if (!llvmResAttr.hasAttributes())
2180+
return;
2181+
SmallVector<DictionaryAttr, 1> resAttrs;
2182+
resAttrs.emplace_back(convertParameterAttribute(llvmResAttr, builder));
2183+
callOp.setResAttrsAttr(getArrayAttr(resAttrs));
2184+
}
2185+
21522186
template <typename Op>
21532187
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
21542188
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,26 @@ static void convertFunctionKernelAttributes(LLVMFuncOp func,
15631563
}
15641564
}
15651565

1566+
static void convertParameterAttr(llvm::AttrBuilder &attrBuilder,
1567+
llvm::Attribute::AttrKind llvmKind,
1568+
NamedAttribute namedAttr,
1569+
ModuleTranslation &moduleTranslation) {
1570+
llvm::TypeSwitch<Attribute>(namedAttr.getValue())
1571+
.Case<TypeAttr>([&](auto typeAttr) {
1572+
attrBuilder.addTypeAttr(
1573+
llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
1574+
})
1575+
.Case<IntegerAttr>([&](auto intAttr) {
1576+
attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1577+
})
1578+
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
1579+
.Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
1580+
attrBuilder.addConstantRangeAttr(
1581+
llvmKind,
1582+
llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
1583+
});
1584+
}
1585+
15661586
FailureOr<llvm::AttrBuilder>
15671587
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
15681588
DictionaryAttr paramAttrs) {
@@ -1573,20 +1593,7 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
15731593
auto it = attrNameToKindMapping.find(namedAttr.getName());
15741594
if (it != attrNameToKindMapping.end()) {
15751595
llvm::Attribute::AttrKind llvmKind = it->second;
1576-
1577-
llvm::TypeSwitch<Attribute>(namedAttr.getValue())
1578-
.Case<TypeAttr>([&](auto typeAttr) {
1579-
attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
1580-
})
1581-
.Case<IntegerAttr>([&](auto intAttr) {
1582-
attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1583-
})
1584-
.Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); })
1585-
.Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
1586-
attrBuilder.addConstantRangeAttr(
1587-
llvmKind, llvm::ConstantRange(rangeAttr.getLower(),
1588-
rangeAttr.getUpper()));
1589-
});
1596+
convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
15901597
} else if (namedAttr.getNameDialect()) {
15911598
if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
15921599
return failure();
@@ -1596,6 +1603,23 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
15961603
return attrBuilder;
15971604
}
15981605

1606+
FailureOr<llvm::AttrBuilder>
1607+
ModuleTranslation::convertParameterAttrs(CallOp, int argIdx,
1608+
DictionaryAttr paramAttrs) {
1609+
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1610+
auto attrNameToKindMapping = getAttrNameToKindMapping();
1611+
1612+
for (auto namedAttr : paramAttrs) {
1613+
auto it = attrNameToKindMapping.find(namedAttr.getName());
1614+
if (it != attrNameToKindMapping.end()) {
1615+
llvm::Attribute::AttrKind llvmKind = it->second;
1616+
convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this);
1617+
}
1618+
}
1619+
1620+
return attrBuilder;
1621+
}
1622+
15991623
LogicalResult ModuleTranslation::convertFunctionSignatures() {
16001624
// Declare all functions first because there may be function calls that form a
16011625
// call graph with cycles, or global initializers that reference functions.

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ func.func @call_missing_ptr_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
235235
func.func private @standard_func_callee()
236236

237237
func.func @call_missing_ptr_type(%arg : i8) {
238+
// expected-error@+2 {{expected '('}}
238239
// expected-error@+1 {{expected direct call to have 1 trailing type}}
239240
llvm.call @standard_func_callee(%arg) : !llvm.ptr, (i8) -> (i8)
240241
llvm.return
@@ -251,6 +252,7 @@ func.func @call_non_pointer_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
251252
// -----
252253

253254
func.func @call_non_function_type(%callee : !llvm.ptr, %arg : i8) {
255+
// expected-error@+2 {{expected '('}}
254256
// expected-error@+1 {{expected trailing function type}}
255257
llvm.call %callee(%arg) : !llvm.ptr, !llvm.func<i8 (i8)>
256258
llvm.return

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,3 +941,23 @@ llvm.func @test_assume_intr_with_opbundles(%arg0 : !llvm.ptr) {
941941
llvm.intr.assume %0 ["tag1"(%1, %2 : i32, i32), "tag2"(%3 : i32)] : i1
942942
llvm.return
943943
}
944+
945+
llvm.func @somefunc(i32, !llvm.ptr)
946+
947+
// CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
948+
// CHECK-SAME: %[[VAL_0:.*]]: i32,
949+
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr)
950+
llvm.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !llvm.ptr) {
951+
// CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
952+
llvm.call @somefunc(%arg0, %arg1) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
953+
llvm.return
954+
}
955+
956+
// CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
957+
// CHECK-SAME: %[[VAL_0:.*]]: i16,
958+
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr
959+
llvm.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: !llvm.ptr) -> i16 {
960+
// CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
961+
%0 = llvm.call tail %arg1(%arg0) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
962+
llvm.return %0 : i16
963+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: mlir-translate -import-llvm %s | FileCheck %s
2+
3+
; CHECK-LABEL: llvm.func @somefunc(i32, !llvm.ptr)
4+
declare void @somefunc(i32, ptr)
5+
6+
; CHECK-LABEL: llvm.func @test_call_arg_attrs_direct(
7+
; CHECK-SAME: %[[VAL_0:.*]]: i32,
8+
; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr)
9+
define void @test_call_arg_attrs_direct(i32 %0, ptr %1) {
10+
; CHECK: llvm.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
11+
call void @somefunc(i32 %0, ptr byval(i64) %1)
12+
ret void
13+
}
14+
15+
; CHECK-LABEL: llvm.func @test_call_arg_attrs_indirect(
16+
; CHECK-SAME: %[[VAL_0:.*]]: i16,
17+
; CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr
18+
define i16 @test_call_arg_attrs_indirect(i16 %0, ptr %1) {
19+
; CHECK: llvm.call tail %[[VAL_1]](%[[VAL_0]]) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
20+
%3 = tail call signext i16 %1(i16 noundef signext %0)
21+
ret i16 %3
22+
}

0 commit comments

Comments
 (0)