-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][LLVM] Add operand bundle support #108933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-codegen @llvm/pr-subscribers-mlir Author: Sirui Mu (Lancern) ChangesThis PR adds LLVM operand bundle support to MLIR LLVM dialect. It affects these 3 operations related to making function calls: This PR adds two new parameters to each of the 3 operations. The first parameter is a variadic operand Patch is 37.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108933.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2da45eba77655b..67d43e4d2e0657 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1210,4 +1210,33 @@ def WorkgroupAttributionAttr
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
}
+//===----------------------------------------------------------------------===//
+// OperandBundleAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_OperandBundleAttr : LLVM_Attr<"OperandBundle", "opbundle"> {
+ let summary = "Operand bundle information";
+ let description = [{
+ Provide information about a single operand bundle. Each operand bundle has a
+ string tag together with various number of SSA value uses. The SSA values
+ are specified through indices into the operation's operand bundle operands.
+ }];
+
+ let parameters = (ins "StringAttr":$tag,
+ OptionalArrayRefParameter<"uint32_t">:$argIndices);
+ let assemblyFormat = [{
+ `<` $tag (`,` $argIndices^)? `>`
+ }];
+}
+
+def LLVM_OperandBundlesAttr : LLVM_Attr<"OperandBundles", "opbundles"> {
+ let summary = "A list of operand bundle attributes";
+ let description = "A list of operand bundle attributes";
+
+ let parameters = (ins ArrayRefParameter<"OperandBundleAttr">:$bundles);
+ let assemblyFormat = [{
+ `<` $bundles `>`
+ }];
+}
+
#endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d956d7f27f784d..b10a7bbad8eea2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -550,8 +550,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
+ Variadic<LLVM_Type>:$bundle_operands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
- DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+ DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+ OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
let results = (outs Optional<LLVM_Type>:$result);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
@@ -587,7 +589,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
//===----------------------------------------------------------------------===//
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
- [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -633,6 +636,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
+ Variadic<LLVM_Type>:$bundle_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
@@ -641,7 +645,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
OptionalAttr<UnitAttr>:$convergent,
OptionalAttr<UnitAttr>:$no_unwind,
- OptionalAttr<UnitAttr>:$will_return
+ OptionalAttr<UnitAttr>:$will_return,
+ OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles
);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
@@ -662,6 +667,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];
+ let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Returns the callee function type.
@@ -1875,21 +1881,28 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
def LLVM_CallIntrinsicOp
: LLVM_Op<"call_intrinsic",
- [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let summary = "Call to an LLVM intrinsic function.";
let description = [{
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
the MLIR function type of this op to determine which intrinsic to call.
}];
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
+ Variadic<LLVM_Type>:$bundle_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
- "{}">:$fastmathFlags);
+ "{}">:$fastmathFlags,
+ OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
let results = (outs Optional<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
}];
let assemblyFormat = [{
- $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+ $intrin `(` $args `)`
+ ( `bundlearg` `(` $bundle_operands^ `)` )?
+ `:` functional-type($args, $results)
+ ( `,` `tuple` `<` type($bundle_operands)^ `>` )?
+ attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4c2e8682285c52..85ec6031dfee70 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promoted, callOp->getAttrs());
+ newOp->setAttr(newOp.getOperandSegmentSizesAttrName(),
+ rewriter.getDenseI32ArrayAttr(
+ {static_cast<int32_t>(promoted.size()), 0}));
+
SmallVector<Value, 4> results;
if (numResults < 2) {
// If < 2 results, packing did not do anything and we can just return.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ca786316324198..6bb8203da1898a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -837,8 +837,12 @@ class FunctionCallPattern
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (callOp.getNumResults() == 0) {
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
+ newOp->setAttr(
+ newOp.getOperandSegmentSizesAttrName(),
+ rewriter.getDenseI32ArrayAttr(
+ {static_cast<int32_t>(adaptor.getOperands().size()), 0}));
return success();
}
@@ -846,8 +850,12 @@ class FunctionCallPattern
auto dstType = typeConverter.convertType(callOp.getType(0));
if (!dstType)
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
+ newOp->setAttr(
+ newOp.getOperandSegmentSizesAttrName(),
+ rewriter.getDenseI32ArrayAttr(
+ {static_cast<int32_t>(adaptor.getOperands().size()), 0}));
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 205d7494d4378c..bb7df718ccad61 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -949,12 +949,14 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
assert(callee && "expected non-null callee in direct call builder");
build(builder, state, results,
- /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+ /*var_callee_type=*/nullptr, callee, args, /*bundle_operands=*/{},
+ /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+ /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+ /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -975,11 +977,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), callee, args,
+ /*bundle_operands=*/{},
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
/*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr,
- /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+ /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundles=*/nullptr,
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
@@ -988,12 +991,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType),
- /*callee=*/nullptr, args,
+ /*callee=*/nullptr, args, /*bundle_operands=*/{},
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+ /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
@@ -1001,11 +1004,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
+ /*bundle_operands=*/{},
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
- /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+ /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
CallInterfaceCallable CallOp::getCallableForCallee() {
@@ -1027,7 +1031,7 @@ void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
}
Operation::operand_range CallOp::getArgOperands() {
- return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+ return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
MutableOperandRange CallOp::getArgOperandsMutable() {
@@ -1100,6 +1104,38 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
return success();
}
+template <typename OpType>
+static LogicalResult verifyOperandBundleOperands(OpType &op) {
+ ValueRange opBundleOperands = op.getBundleOperands();
+ OperandBundlesAttr opBundles = op.getOpBundlesAttr();
+
+ if (!opBundles) {
+ if (!opBundleOperands.empty())
+ return op.emitError("expected operand bundles attribute");
+ return success();
+ }
+
+ DenseSet<uint32_t> seenOperandIdx;
+ for (OperandBundleAttr bundle : opBundles.getBundles()) {
+ for (uint32_t bundleOperandIdx : bundle.getArgIndices()) {
+ if (bundleOperandIdx >= opBundleOperands.size())
+ return op.emitError("operand bundle argument index ")
+ << bundleOperandIdx << " is out of range";
+ seenOperandIdx.insert(bundleOperandIdx);
+ }
+ }
+
+ for (uint32_t idx = 0; idx < opBundleOperands.size(); ++idx) {
+ if (!seenOperandIdx.contains(idx))
+ return op.emitError("operand bundle argument at index ")
+ << idx << " is not included in any operand bundles";
+ }
+
+ return success();
+}
+
+LogicalResult CallOp::verify() { return verifyOperandBundleOperands(*this); }
+
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(verifyCallOpVarCalleeType(*this)))
return failure();
@@ -1150,15 +1186,15 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Verify that the operand and result types match the callee.
if (!funcType.isVarArg() &&
- funcType.getNumParams() != (getNumOperands() - isIndirect))
+ funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
return emitOpError() << "incorrect number of operands ("
- << (getNumOperands() - isIndirect)
+ << (getCalleeOperands().size() - isIndirect)
<< ") for callee (expecting: "
<< funcType.getNumParams() << ")";
- if (funcType.getNumParams() > (getNumOperands() - isIndirect))
+ if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
return emitOpError() << "incorrect number of operands ("
- << (getNumOperands() - isIndirect)
+ << (getCalleeOperands().size() - isIndirect)
<< ") for varargs callee (expecting at least: "
<< funcType.getNumParams() << ")";
@@ -1208,16 +1244,24 @@ void CallOp::print(OpAsmPrinter &p) {
else
p << getOperand(0);
- auto args = getOperands().drop_front(isDirect ? 0 : 1);
+ auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
// Print the variadic callee type if the call is variadic.
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
p << " vararg(" << *varCalleeType << ")";
+ // Print the operand bundles, if any.
+ if (!getBundleOperands().empty()) {
+ p << " bundlearg(";
+ p.printOperands(getBundleOperands());
+ p << ")";
+ }
+
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{getCalleeAttrName(), getTailCallKindAttrName(),
- getVarCalleeTypeAttrName(), getCConvAttrName()});
+ getVarCalleeTypeAttrName(), getCConvAttrName(),
+ getOperandSegmentSizesAttrName()});
p << " : ";
if (!isDirect)
@@ -1225,24 +1269,53 @@ void CallOp::print(OpAsmPrinter &p) {
// Reconstruct the function MLIR function type from operand and result types.
p.printFunctionalType(args.getTypes(), getResultTypes());
+
+ if (!getBundleOperands().empty()) {
+ SmallVector<Type> opBundleArgTypes;
+ opBundleArgTypes.reserve(getBundleOperands().size());
+ for (auto opBundleArg : getBundleOperands())
+ opBundleArgTypes.push_back(opBundleArg.getType());
+
+ p << ", tuple<";
+ llvm::interleaveComma(opBundleArgTypes, p);
+ p << ">";
+ }
}
/// Parses the type of a call operation and resolves the operands if the parsing
/// succeeds. Returns failure otherwise.
static ParseResult parseCallTypeAndResolveOperands(
OpAsmParser &parser, OperationState &result, bool isDirect,
- ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+ ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+ ArrayRef<OpAsmParser::UnresolvedOperand> opBundleOperands) {
SMLoc trailingTypesLoc = parser.getCurrentLocation();
SmallVector<Type> types;
if (parser.parseColonTypeList(types))
return failure();
- if (isDirect && types.size() != 1)
- return parser.emitError(trailingTypesLoc,
- "expected direct call to have 1 trailing type");
- if (!isDirect && types.size() != 2)
+ if (isDirect && opBundleOperands.empty() && types.size() != 1)
+ return parser.emitError(
+ trailingTypesLoc,
+ "expected direct call without operand bundles to have 1 trailing type");
+ if (isDirect && !opBundleOperands.empty() && types.size() != 2)
+ return parser.emitError(
+ trailingTypesLoc,
+ "expected direct call with operand bundles to have 2 trailing types");
+ if (!isDirect && opBundleOperands.empty() && types.size() != 2)
return parser.emitError(trailingTypesLoc,
- "expected indirect call to have 2 trailing types");
+ "expected indirect call without operand bundles to "
+ "have 2 trailing types");
+ if (!isDirect && !opBundleOperands.empty() && types.size() != 3)
+ return parser.emitError(
+ trailingTypesLoc,
+ "expected indirect call with operand bundles to have 3 trailing types");
+
+ TupleType opBundleTypes;
+ if (!opBundleOperands.empty()) {
+ opBundleTypes = llvm::dyn_cast<TupleType>(types.pop_back_val());
+ if (!opBundleTypes)
+ return parser.emitError(trailingTypesLoc, "expected trailing tuple type");
+ }
auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
if (!funcType)
@@ -1267,6 +1340,12 @@ static ParseResult parseCallTypeAndResolveOperands(
if (funcType.getNumResults() != 0)
result.addTypes(funcType.getResults());
+ if (!opBundleOperands.empty()) {
+ if (parser.resolveOperands(opBundleOperands, opBundleTypes.getTypes(),
+ parser.getNameLoc(), result.operands))
+ return failure();
+ }
+
return success();
}
@@ -1288,7 +1367,9 @@ static ParseResult parseOptionalCallFuncPtr(
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-callee-type `)` )?
+// ( `bundlearg(` ssa-use-list `)` )?
// attribute-dict? `:` (type `,`)? function-type
+// (`,` `tuple` `<` type-list `>`)?
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
TypeAttr varCalleeType;
@@ -1333,11 +1414,25 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
+ SmallVector<OpAsmParser::UnresolvedOperand> opBundleOperands;
+ bool hasOpBundles = parser.parseOptionalKeyword("bundlearg").succeeded();
+ if (hasOpBundles &&
+ parser.parseOperandList(opBundleOperands, OpAsmParser::Delimiter::Paren))
+ return failure();
+
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the trailing type list and resolve the operands.
- return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+ if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+ opBundleOperands))
+ return failure();
+
+ result.addAttribute(CallOp::getOperandSegmentSizeAttr(),
+ parser.getBuilder().getDenseI32ArrayAttr(
+ {static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(opBundleOperands.size())}));
+ retu...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason we cannot use VariadicOfVariadic
for bundle_operands
?
op_bundles
would then need to be a DenseI32ArrayAttr
, but ought to be kept up to sync automatically by all APIs I think.
Each operand bundle has a string tag associated with it, thus I have to invent the |
I believe this should result in less code and things that could go out of sync. Keeping the number of operands in the current With the |
@zero9178 I'll update to follow your proposed approach. |
bbb2bc2
to
7a7bc80
Compare
@zero9178 I've updated to use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much! This is pretty much as I imagined conceptually.
I still have one more high-level design question about the syntax before I will do a more thorough pass over the actual code, though I did leave some comments in the code where I noticed issues.
7a7bc80
to
7a41798
Compare
@zero9178 Hi I've updated the PR to use the proposed syntax. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect, thank you so much! I added some nits regarding code-style and found oen case that I believe is a bug. Otherwise I think this will soon be good to go
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this!
I added some mostly minor comments. It seems like some verifier tests and probably also some roundtrip.mlir tests for printing and parsing could make sense?
If you have time to add support for the llvm.intr.assume that would be great.
I think mid term we also want to support operand bundles when importing from LLVM IR into MLIR LLVM dialect (ModuleImport.cpp). However, given the size of the PR it may make sense to do that in a later stage.
llvm.func @call_intrin_with_opbundle(%arg0 : !llvm.ptr) { | ||
%0 = llvm.mlir.constant(1 : i1) : i1 | ||
%1 = llvm.mlir.constant(16 : i32) : i32 | ||
llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also an "llvm.intr.assume" intrinsic that would benefit from operand bundle support. It is defined in LLVMIntrinsicOps.td. I wonder if this could be updated as well? AFAIK most intrinsics are nowadays represented using specific operations rather than with the generic llvm.call_intrinsic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll make another PR for llvm.intr.assume
after this PR lands. After all this PR is about operand bundles :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is worth having tests in roundtrip.mlir since the printing and parsing is non-trivial here?
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM once @gysit's comments are addressed and he is happy 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the tests! LGTM
7a54369
to
4351398
Compare
Rebased onto the latest |
BTW I don't have commit access so if all reviewers are happy about this PR someone needs to land it. |
There seems to be a flang issue caused by producing an invalid |
Once CI is green one of us will happily click the button. As this is not your first contribution - and since we hope for many more :) - you can consider to apply for commit access (https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access) once this landed. |
Fixed these regressions in my local environment. Let's see if the CI passes.
Great, I'll apply for the commit access after this PR lands! |
The CI is green now. |
This PR adds LLVM [operand bundle](https://llvm.org/docs/LangRef.html#operand-bundles) support to MLIR LLVM dialect. It affects these 3 operations related to making function calls: `llvm.call`, `llvm.invoke`, and `llvm.call_intrinsic`. This PR adds two new parameters to each of the 3 operations. The first parameter is a variadic operand `op_bundle_operands` that contains the SSA values for operand bundles. The second parameter is a property `op_bundle_tags` which holds an array of strings that represent the tags of each operand bundle.
Are the operand bundle a very core part of the llvm call ops? IIUC it's just auxiliary data? I suppose most of the time a user won't need to specify anything there. For such cases it would be nice to have some default builder or helper functions to make it transparent for downstream users. Right now it's kinda forcing all downstream users to add their wrappers to just ignore these newly added fields, e.g., triton-lang/triton#4847. |
@antiagainst Hi Zhang, thanks for your feedback!
llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td Lines 685 to 700 in 1e5e153
For llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td Lines 567 to 576 in 1e5e153
All of the above custom builders set the relevant attributes to their default values and you don't have to specify the operand bundle related attributes by hand if you use these builders. We could even add more builders if these builders do not cover your use cases. I noticed that your code uses this builder overload: static void build(mlir::OpBuilder &, mlir::OperationState &odsState, mlir::TypeRange resultTypes, mlir::ValueRange operands, llvm::ArrayRef<mlir::NamedAttribute> attributes = {}); Implementation of this builder overload is generated automatically by MLIR and it lets you build the operation with every detail by hand. If you're using this overload you have to specify the operand bundle related attributes manually. I'm afraid we couldn't do much about this since this builder is generated by MLIR. However |
This updates LLVM to pull in two fixes we need for AMD: * llvm/llvm-project#110553 * llvm/llvm-project#104743 Fixed `LLVM::CallOp` and `LLVM::CallIntrinsicOp` builder API after * llvm/llvm-project#108933
Extra builders for CallIntrinsicOp. This is inspired by the comment from @antiagainst from [here](#108933 (comment)).
Extra builders for CallIntrinsicOp. This is inspired by the comment from @antiagainst from [here](llvm#108933 (comment)).
This updates LLVM to pull in two fixes we need for AMD: * llvm/llvm-project#110553 * llvm/llvm-project#104743 Fixed `LLVM::CallOp` and `LLVM::CallIntrinsicOp` builder API after * llvm/llvm-project#108933
This updates LLVM to pull in two fixes we need for AMD: * llvm/llvm-project#110553 * llvm/llvm-project#104743 Fixed `LLVM::CallOp` and `LLVM::CallIntrinsicOp` builder API after * llvm/llvm-project#108933
@Lancern why are these additions not marked optional in the op definition? |
This PR adds LLVM operand bundle support to MLIR LLVM dialect. It affects these 3 operations related to making function calls:
llvm.call
,llvm.invoke
, andllvm.call_intrinsic
.This PR adds two new parameters to each of the 3 operations. The first parameter is a variadic operand
op_bundle_operands
that contains the SSA values for operand bundles. The second parameter is a propertyop_bundle_tags
which holds an array of strings that represent the tags of each operand bundle.