Skip to content

[MLIR][LLVMIR] llvm.call_intrinsic: support operand/result attributes #129640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2107,18 +2107,14 @@ def LLVM_CallIntrinsicOp
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
OptionalAttr<ArrayAttr>:$op_bundle_tags);
OptionalAttr<ArrayAttr>:$op_bundle_tags,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
let results = (outs Optional<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
}];
let assemblyFormat = [{
$intrin `(` $args `)`
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
$op_bundle_tags)^ )?
`:` functional-type($args, $results)
attr-dict
}];
let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "StringAttr":$intrin, "ValueRange":$args)>,
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ class ModuleImport {
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);

/// Converts the parameter and result attributes in `argsAttr` and `resAttr`
/// and add them to the `callOp`.
void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr,
ArrayAttr &resAttr, OpBuilder &builder);

private:
/// Clears the accumulated state before processing a new region.
void clearRegionState() {
Expand Down Expand Up @@ -350,7 +355,8 @@ class ModuleImport {
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
OpBuilder &builder);
/// Converts the parameter and result attributes attached to `call` and adds
/// them to the `callOp`.
/// them to the `callOp`. Implemented in terms of the the public definition of
/// convertParameterAttributes.
void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this function would be a good use case for a ArgumentAndResultAttributeInterface.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean In the future or as part of this PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is definitely a separate PR.

OpBuilder &builder);
/// Converts the attributes attached to `inst` and adds them to the `op`.
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class ModuleTranslation {

/// Translates parameter attributes of a call and adds them to the returned
/// AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder> convertParameterAttrs(CallOpInterface callOp,
FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
DictionaryAttr paramAttrs);

/// Gets the named metadata in the LLVM IR module being constructed, creating
Expand Down
105 changes: 101 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3547,30 +3547,127 @@ void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
FastmathFlagsAttr{},
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::StringAttr intrin, mlir::ValueRange args,
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
fastMathFlags,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::Type resultType, mlir::StringAttr intrin,
mlir::ValueRange args) {
build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
mlir::TypeRange resultTypes,
mlir::StringAttr intrin, mlir::ValueRange args,
mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
build(builder, state, resultTypes, intrin, args, fastMathFlags,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{});
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
/*res_attrs=*/{});
}

ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr intrinAttr;
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
SmallVector<SmallVector<Type>> opBundleOperandTypes;
ArrayAttr opBundleTags;

// Parse intrinsic name.
if (parser.parseCustomAttributeWithFallback(
intrinAttr, parser.getBuilder().getType<NoneType>()))
return failure();
result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
intrinAttr);

if (parser.parseLParen())
return failure();

// Parse the function arguments.
if (parser.parseOperandList(operands))
return mlir::failure();

if (parser.parseRParen())
return mlir::failure();

// Handle bundles.
SMLoc opBundlesLoc = parser.getCurrentLocation();
if (std::optional<ParseResult> result = parseOpBundles(
parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
result && failed(*result))
return failure();
if (opBundleTags && !opBundleTags.empty())
result.addAttribute(
CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
opBundleTags);

if (parser.parseOptionalAttrDict(result.attributes))
return mlir::failure();

SmallVector<DictionaryAttr> argAttrs;
SmallVector<DictionaryAttr> resultAttrs;
if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
operands, argAttrs, resultAttrs))
return failure();
call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, argAttrs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
opBundleOperandTypes,
getOpBundleSizesAttrName(result.name)))
return failure();

int32_t numOpBundleOperands = 0;
for (const auto &operands : opBundleOperands)
numOpBundleOperands += operands.size();

result.addAttribute(
CallIntrinsicOp::getOperandSegmentSizeAttr(),
parser.getBuilder().getDenseI32ArrayAttr(
{static_cast<int32_t>(operands.size()), numOpBundleOperands}));

return mlir::success();
}

void CallIntrinsicOp::print(OpAsmPrinter &p) {
p << ' ';
p.printAttributeWithoutType(getIntrinAttr());

OperandRange args = getArgs();
p << "(" << args << ")";

// Operand bundles.
if (!getOpBundleOperands().empty()) {
p << ' ';
printOpBundles(p, *this, getOpBundleOperands(),
getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
}

p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{getOperandSegmentSizesAttrName(),
getOpBundleSizesAttrName(), getIntrinAttrName(),
getOpBundleTagsAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName()});

p << " : ";

// Reconstruct the MLIR function type from operand and result types.
call_interface_impl::printFunctionSignature(
p, args.getTypes(), getArgAttrsAttr(),
/*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}

//===----------------------------------------------------------------------===//
Expand Down
79 changes: 46 additions & 33 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}

static LogicalResult
convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
ArrayAttr resAttrsArray, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
if (argAttrsArray) {
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
!argAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(loc, argAttrs);
if (failed(attrBuilder))
return failure();
call->addParamAttrs(argIdx, *attrBuilder);
}
}
}

if (resAttrsArray && resAttrsArray.size() > 0) {
if (resAttrsArray.size() != 1)
return mlir::emitError(loc, "llvm.func cannot have multiple results");
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
!resAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(loc, resAttrs);
if (failed(attrBuilder))
return failure();
call->addRetAttrs(*attrBuilder);
}
}
return success();
}

static LogicalResult
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
return convertParameterAndResultAttrs(
callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
moduleTranslation);
}

/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
Expand Down Expand Up @@ -201,6 +241,12 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
fn, moduleTranslation.lookupValues(op.getArgs()),
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));

if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
op.getResAttrsAttr(), inst,
moduleTranslation)))
return failure();

if (op.getNumResults() == 1)
moduleTranslation.mapValue(op->getResults().front()) = inst;
return success();
Expand All @@ -224,39 +270,6 @@ static void convertLinkerOptionsOp(ArrayAttr options,
linkerMDNode->addOperand(listMDNode);
}

static LogicalResult
convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
LLVM::ModuleTranslation &moduleTranslation) {
if (ArrayAttr argAttrsArray = callOp.getArgAttrsAttr()) {
for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
!argAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(callOp, argAttrs);
if (failed(attrBuilder))
return failure();
call->addParamAttrs(argIdx, *attrBuilder);
}
}
}

ArrayAttr resAttrsArray = callOp.getResAttrsAttr();
if (resAttrsArray && resAttrsArray.size() > 0) {
if (resAttrsArray.size() != 1)
return mlir::emitError(callOp.getLoc(),
"llvm.func cannot have multiple results");
if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
!resAttrs.empty()) {
FailureOr<llvm::AttrBuilder> attrBuilder =
moduleTranslation.convertParameterAttrs(callOp, resAttrs);
if (failed(attrBuilder))
return failure();
call->addRetAttrs(*attrBuilder);
}
}
return success();
}

static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(

moduleImport.setFastmathFlagsAttr(inst, op);

ArrayAttr argsAttr, resAttr;
moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
op.setArgAttrsAttr(argsAttr);
op.setResAttrsAttr(resAttr);

// Update importer tracking of results.
unsigned numRes = op.getNumResults();
if (numRes == 1)
Expand Down
16 changes: 13 additions & 3 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2213,7 +2213,8 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
}

void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
ArrayAttr &argsAttr,
ArrayAttr &resAttr,
OpBuilder &builder) {
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
Expand All @@ -2233,14 +2234,23 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
callOp.setArgAttrsAttr(getArrayAttr(argAttrs));
argsAttr = getArrayAttr(argAttrs);
}

llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
callOp.setResAttrsAttr(getArrayAttr({resAttrs}));
resAttr = getArrayAttr({resAttrs});
}

void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
CallOpInterface callOp,
OpBuilder &builder) {
ArrayAttr argsAttr, resAttr;
convertParameterAttributes(call, argsAttr, resAttr, builder);
callOp.setArgAttrsAttr(argsAttr);
callOp.setResAttrsAttr(resAttr);
}

template <typename Op>
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,10 +1696,9 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
}

FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(CallOpInterface callOp,
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
Location loc = callOp.getLoc();
auto attrNameToKindMapping = getAttrNameToKindMapping();

for (auto namedAttr : paramAttrs) {
Expand Down
19 changes: 14 additions & 5 deletions mlir/test/Dialect/LLVMIR/call-intrin.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
llvm.func @round_sse41() -> vector<4xf32> {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) {fastmathFlags = #llvm.fastmath<reassoc>} : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32>
llvm.return %res: vector<4xf32>
}

Expand All @@ -19,7 +19,7 @@ llvm.func @round_sse41() -> vector<4xf32> {
// CHECK: }
llvm.func @round_overloaded() -> f32 {
%0 = llvm.mlir.constant(1.0 : f32) : f32
%res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
%res = llvm.call_intrinsic "llvm.round"(%0) {} : (f32) -> f32
llvm.return %res: f32
}

Expand All @@ -34,7 +34,7 @@ llvm.func @lifetime_start() {
%0 = llvm.mlir.constant(4 : i64) : i64
%1 = llvm.mlir.constant(1 : i8) : i8
%2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) {} : (i64, !llvm.ptr) -> ()
llvm.return
}

Expand Down Expand Up @@ -64,7 +64,7 @@ llvm.func @bad_types() {
%0 = llvm.mlir.constant(1 : i8) : i8
// expected-error@below {{call intrinsic signature i8 (i8) to overloaded intrinsic "llvm.round" does not match any of the overloads}}
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
llvm.call_intrinsic "llvm.round"(%0) {} : (i8) -> i8
llvm.return
}

Expand Down Expand Up @@ -102,6 +102,15 @@ llvm.func @bad_args() {
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
// expected-error @below {{intrinsic call operand #2 has type i64 but "llvm.x86.sse41.round.ss" expects i32}}
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) {fastmathFlags = #llvm.fastmath<reassoc>} : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32>
llvm.return
}

// -----

// CHECK-LABEL: intrinsic_call_arg_attrs
llvm.func @intrinsic_call_arg_attrs(%arg0: i32) -> i32 {
// CHECK: call i32 @llvm.riscv.sha256sig0(i32 signext %{{.*}})
%0 = llvm.call_intrinsic "llvm.riscv.sha256sig0"(%arg0) : (i32 {llvm.signext}) -> (i32)
llvm.return %0 : i32
}
Loading