Skip to content

Commit 2bbe30b

Browse files
bcardosolopesgysit
andauthored
[MLIR][LLVMIR] llvm.call_intrinsic: support operand/result attributes (#129640)
Basically catch up with llvm.call and add support for translate and import to LLVM IR. This PR is split into two commits in case it's easier to review the refactoring part, which comes first (happy to split the PR if necessary). --------- Co-authored-by: Tobias Gysi <[email protected]>
1 parent 35622a9 commit 2bbe30b

File tree

12 files changed

+226
-57
lines changed

12 files changed

+226
-57
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,18 +2107,14 @@ def LLVM_CallIntrinsicOp
21072107
VariadicOfVariadic<LLVM_Type,
21082108
"op_bundle_sizes">:$op_bundle_operands,
21092109
DenseI32ArrayAttr:$op_bundle_sizes,
2110-
OptionalAttr<ArrayAttr>:$op_bundle_tags);
2110+
OptionalAttr<ArrayAttr>:$op_bundle_tags,
2111+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2112+
OptionalAttr<DictArrayAttr>:$res_attrs);
21112113
let results = (outs Optional<LLVM_Type>:$results);
21122114
let llvmBuilder = [{
21132115
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
21142116
}];
2115-
let assemblyFormat = [{
2116-
$intrin `(` $args `)`
2117-
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
2118-
$op_bundle_tags)^ )?
2119-
`:` functional-type($args, $results)
2120-
attr-dict
2121-
}];
2117+
let hasCustomAssemblyFormat = 1;
21222118

21232119
let builders = [
21242120
OpBuilder<(ins "StringAttr":$intrin, "ValueRange":$args)>,

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,11 @@ class ModuleImport {
272272
SmallVectorImpl<Value> &valuesOut,
273273
SmallVectorImpl<NamedAttribute> &attrsOut);
274274

275+
/// Converts the parameter and result attributes in `argsAttr` and `resAttr`
276+
/// and add them to the `callOp`.
277+
void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr,
278+
ArrayAttr &resAttr, OpBuilder &builder);
279+
275280
private:
276281
/// Clears the accumulated state before processing a new region.
277282
void clearRegionState() {
@@ -350,7 +355,8 @@ class ModuleImport {
350355
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
351356
OpBuilder &builder);
352357
/// Converts the parameter and result attributes attached to `call` and adds
353-
/// them to the `callOp`.
358+
/// them to the `callOp`. Implemented in terms of the the public definition of
359+
/// convertParameterAttributes.
354360
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
355361
OpBuilder &builder);
356362
/// Converts the attributes attached to `inst` and adds them to the `op`.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ class ModuleTranslation {
237237

238238
/// Translates parameter attributes of a call and adds them to the returned
239239
/// AttrBuilder. Returns failure if any of the translations failed.
240-
FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOpInterface callOp,
240+
FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
241241
DictionaryAttr paramAttrs);
242242

243243
/// Gets the named metadata in the LLVM IR module being constructed, creating

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3547,30 +3547,127 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
35473547
mlir::StringAttr intrin, mlir::ValueRange args) {
35483548
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
35493549
FastmathFlagsAttr{},
3550-
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
3550+
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3551+
/*res_attrs=*/{});
35513552
}
35523553

35533554
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
35543555
mlir::StringAttr intrin, mlir::ValueRange args,
35553556
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
35563557
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
35573558
fastMathFlags,
3558-
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
3559+
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3560+
/*res_attrs=*/{});
35593561
}
35603562

35613563
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
35623564
mlir::Type resultType, mlir::StringAttr intrin,
35633565
mlir::ValueRange args) {
35643566
build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
3565-
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
3567+
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3568+
/*res_attrs=*/{});
35663569
}
35673570

35683571
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
35693572
mlir::TypeRange resultTypes,
35703573
mlir::StringAttr intrin, mlir::ValueRange args,
35713574
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
35723575
build(builder, state, resultTypes, intrin, args, fastMathFlags,
3573-
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
3576+
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3577+
/*res_attrs=*/{});
3578+
}
3579+
3580+
ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
3581+
OperationState &result) {
3582+
StringAttr intrinAttr;
3583+
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3584+
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
3585+
SmallVector<SmallVector<Type>> opBundleOperandTypes;
3586+
ArrayAttr opBundleTags;
3587+
3588+
// Parse intrinsic name.
3589+
if (parser.parseCustomAttributeWithFallback(
3590+
intrinAttr, parser.getBuilder().getType<NoneType>()))
3591+
return failure();
3592+
result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
3593+
intrinAttr);
3594+
3595+
if (parser.parseLParen())
3596+
return failure();
3597+
3598+
// Parse the function arguments.
3599+
if (parser.parseOperandList(operands))
3600+
return mlir::failure();
3601+
3602+
if (parser.parseRParen())
3603+
return mlir::failure();
3604+
3605+
// Handle bundles.
3606+
SMLoc opBundlesLoc = parser.getCurrentLocation();
3607+
if (std::optional<ParseResult> result = parseOpBundles(
3608+
parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
3609+
result && failed(*result))
3610+
return failure();
3611+
if (opBundleTags && !opBundleTags.empty())
3612+
result.addAttribute(
3613+
CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
3614+
opBundleTags);
3615+
3616+
if (parser.parseOptionalAttrDict(result.attributes))
3617+
return mlir::failure();
3618+
3619+
SmallVector<DictionaryAttr> argAttrs;
3620+
SmallVector<DictionaryAttr> resultAttrs;
3621+
if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
3622+
operands, argAttrs, resultAttrs))
3623+
return failure();
3624+
call_interface_impl::addArgAndResultAttrs(
3625+
parser.getBuilder(), result, argAttrs, resultAttrs,
3626+
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
3627+
3628+
if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
3629+
opBundleOperandTypes,
3630+
getOpBundleSizesAttrName(result.name)))
3631+
return failure();
3632+
3633+
int32_t numOpBundleOperands = 0;
3634+
for (const auto &operands : opBundleOperands)
3635+
numOpBundleOperands += operands.size();
3636+
3637+
result.addAttribute(
3638+
CallIntrinsicOp::getOperandSegmentSizeAttr(),
3639+
parser.getBuilder().getDenseI32ArrayAttr(
3640+
{static_cast<int32_t>(operands.size()), numOpBundleOperands}));
3641+
3642+
return mlir::success();
3643+
}
3644+
3645+
void CallIntrinsicOp::print(OpAsmPrinter &p) {
3646+
p << ' ';
3647+
p.printAttributeWithoutType(getIntrinAttr());
3648+
3649+
OperandRange args = getArgs();
3650+
p << "(" << args << ")";
3651+
3652+
// Operand bundles.
3653+
if (!getOpBundleOperands().empty()) {
3654+
p << ' ';
3655+
printOpBundles(p, *this, getOpBundleOperands(),
3656+
getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
3657+
}
3658+
3659+
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
3660+
{getOperandSegmentSizesAttrName(),
3661+
getOpBundleSizesAttrName(), getIntrinAttrName(),
3662+
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
3663+
getResAttrsAttrName()});
3664+
3665+
p << " : ";
3666+
3667+
// Reconstruct the MLIR function type from operand and result types.
3668+
call_interface_impl::printFunctionSignature(
3669+
p, args.getTypes(), getArgAttrsAttr(),
3670+
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
35743671
}
35753672

35763673
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,46 @@ convertOperandBundles(OperandRangeRange bundleOperands,
135135
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
136136
}
137137

138+
static LogicalResult
139+
convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
140+
ArrayAttr resAttrsArray, llvm::CallBase *call,
141+
LLVM::ModuleTranslation &moduleTranslation) {
142+
if (argAttrsArray) {
143+
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
144+
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
145+
!argAttrs.empty()) {
146+
FailureOr<llvm::AttrBuilder> attrBuilder =
147+
moduleTranslation.convertParameterAttrs(loc, argAttrs);
148+
if (failed(attrBuilder))
149+
return failure();
150+
call->addParamAttrs(argIdx, *attrBuilder);
151+
}
152+
}
153+
}
154+
155+
if (resAttrsArray && resAttrsArray.size() > 0) {
156+
if (resAttrsArray.size() != 1)
157+
return mlir::emitError(loc, "llvm.func cannot have multiple results");
158+
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
159+
!resAttrs.empty()) {
160+
FailureOr<llvm::AttrBuilder> attrBuilder =
161+
moduleTranslation.convertParameterAttrs(loc, resAttrs);
162+
if (failed(attrBuilder))
163+
return failure();
164+
call->addRetAttrs(*attrBuilder);
165+
}
166+
}
167+
return success();
168+
}
169+
170+
static LogicalResult
171+
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
172+
LLVM::ModuleTranslation &moduleTranslation) {
173+
return convertParameterAndResultAttrs(
174+
callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
175+
moduleTranslation);
176+
}
177+
138178
/// Builder for LLVM_CallIntrinsicOp
139179
static LogicalResult
140180
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -201,6 +241,12 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
201241
fn, moduleTranslation.lookupValues(op.getArgs()),
202242
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
203243
moduleTranslation));
244+
245+
if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
246+
op.getResAttrsAttr(), inst,
247+
moduleTranslation)))
248+
return failure();
249+
204250
if (op.getNumResults() == 1)
205251
moduleTranslation.mapValue(op->getResults().front()) = inst;
206252
return success();
@@ -224,39 +270,6 @@ static void convertLinkerOptionsOp(ArrayAttr options,
224270
linkerMDNode->addOperand(listMDNode);
225271
}
226272

227-
static LogicalResult
228-
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
229-
LLVM::ModuleTranslation &moduleTranslation) {
230-
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
231-
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
232-
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
233-
!argAttrs.empty()) {
234-
FailureOr<llvm::AttrBuilder> attrBuilder =
235-
moduleTranslation.convertParameterAttrs(callOp, argAttrs);
236-
if (failed(attrBuilder))
237-
return failure();
238-
call->addParamAttrs(argIdx, *attrBuilder);
239-
}
240-
}
241-
}
242-
243-
ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
244-
if (resAttrsArray && resAttrsArray.size() > 0) {
245-
if (resAttrsArray.size() != 1)
246-
return mlir::emitError(callOp.getLoc(),
247-
"llvm.func cannot have multiple results");
248-
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
249-
!resAttrs.empty()) {
250-
FailureOr<llvm::AttrBuilder> attrBuilder =
251-
moduleTranslation.convertParameterAttrs(callOp, resAttrs);
252-
if (failed(attrBuilder))
253-
return failure();
254-
call->addRetAttrs(*attrBuilder);
255-
}
256-
}
257-
return success();
258-
}
259-
260273
static LogicalResult
261274
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
262275
LLVM::ModuleTranslation &moduleTranslation) {

mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
4545

4646
moduleImport.setFastmathFlagsAttr(inst, op);
4747

48+
ArrayAttr argsAttr, resAttr;
49+
moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
50+
op.setArgAttrsAttr(argsAttr);
51+
op.setResAttrsAttr(resAttr);
52+
4853
// Update importer tracking of results.
4954
unsigned numRes = op.getNumResults();
5055
if (numRes == 1)

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,7 +2213,8 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
22132213
}
22142214

22152215
void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
2216-
CallOpInterface callOp,
2216+
ArrayAttr &argsAttr,
2217+
ArrayAttr &resAttr,
22172218
OpBuilder &builder) {
22182219
llvm::AttributeList llvmAttrs = call->getAttributes();
22192220
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
@@ -2233,14 +2234,23 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
22332234
SmallVector<DictionaryAttr> argAttrs;
22342235
for (auto &llvmArgAttrs : llvmArgAttrsSet)
22352236
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
2236-
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
2237+
argsAttr = getArrayAttr(argAttrs);
22372238
}
22382239

22392240
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
22402241
if (!llvmResAttr.hasAttributes())
22412242
return;
22422243
DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
2243-
callOp.setResAttrsAttr(getArrayAttr({resAttrs}));
2244+
resAttr = getArrayAttr({resAttrs});
2245+
}
2246+
2247+
void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
2248+
CallOpInterface callOp,
2249+
OpBuilder &builder) {
2250+
ArrayAttr argsAttr, resAttr;
2251+
convertParameterAttributes(call, argsAttr, resAttr, builder);
2252+
callOp.setArgAttrsAttr(argsAttr);
2253+
callOp.setResAttrsAttr(resAttr);
22442254
}
22452255

22462256
template <typename Op>

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,10 +1701,9 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
17011701
}
17021702

17031703
FailureOr<llvm::AttrBuilder>
1704-
ModuleTranslation::convertParameterAttrs(CallOpInterface callOp,
1704+
ModuleTranslation::convertParameterAttrs(Location loc,
17051705
DictionaryAttr paramAttrs) {
17061706
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1707-
Location loc = callOp.getLoc();
17081707
auto attrNameToKindMapping = getAttrNameToKindMapping();
17091708

17101709
for (auto namedAttr : paramAttrs) {

mlir/test/Dialect/LLVMIR/call-intrin.mlir

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
llvm.func @round_sse41() -> vector<4xf32> {
88
%0 = llvm.mlir.constant(1 : i32) : i32
99
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
10-
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
10+
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) {fastmathFlags = #llvm.fastmath<reassoc>} : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32>
1111
llvm.return %res: vector<4xf32>
1212
}
1313

@@ -19,7 +19,7 @@ llvm.func @round_sse41() -> vector<4xf32> {
1919
// CHECK: }
2020
llvm.func @round_overloaded() -> f32 {
2121
%0 = llvm.mlir.constant(1.0 : f32) : f32
22-
%res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
22+
%res = llvm.call_intrinsic "llvm.round"(%0) {} : (f32) -> f32
2323
llvm.return %res: f32
2424
}
2525

@@ -34,7 +34,7 @@ llvm.func @lifetime_start() {
3434
%0 = llvm.mlir.constant(4 : i64) : i64
3535
%1 = llvm.mlir.constant(1 : i8) : i8
3636
%2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
37-
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
37+
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) {} : (i64, !llvm.ptr) -> ()
3838
llvm.return
3939
}
4040

@@ -64,7 +64,7 @@ llvm.func @bad_types() {
6464
%0 = llvm.mlir.constant(1 : i8) : i8
6565
// expected-error@below {{call intrinsic signature i8 (i8) to overloaded intrinsic "llvm.round" does not match any of the overloads}}
6666
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
67-
llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
67+
llvm.call_intrinsic "llvm.round"(%0) {} : (i8) -> i8
6868
llvm.return
6969
}
7070

@@ -102,6 +102,15 @@ llvm.func @bad_args() {
102102
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
103103
// expected-error @below {{intrinsic call operand #2 has type i64 but "llvm.x86.sse41.round.ss" expects i32}}
104104
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
105-
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
105+
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) {fastmathFlags = #llvm.fastmath<reassoc>} : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32>
106106
llvm.return
107107
}
108+
109+
// -----
110+
111+
// CHECK-LABEL: intrinsic_call_arg_attrs
112+
llvm.func @intrinsic_call_arg_attrs(%arg0: i32) -> i32 {
113+
// CHECK: call i32 @llvm.riscv.sha256sig0(i32 signext %{{.*}})
114+
%0 = llvm.call_intrinsic "llvm.riscv.sha256sig0"(%arg0) : (i32 {llvm.signext}) -> (i32)
115+
llvm.return %0 : i32
116+
}

0 commit comments

Comments
 (0)