Skip to content

[mlir][LLVM] Add operand bundle support #108933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 26, 2024
Merged

Conversation

Lancern
Copy link
Member

@Lancern Lancern commented Sep 17, 2024

This PR adds LLVM operand bundle support to MLIR LLVM dialect. It affects these 3 operations related to making function calls: llvm.call, llvm.invoke, and llvm.call_intrinsic.

This PR adds two new parameters to each of the 3 operations. The first parameter is a variadic operand op_bundle_operands that contains the SSA values for operand bundles. The second parameter is a property op_bundle_tags which holds an array of strings that represent the tags of each operand bundle.

@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Sirui Mu (Lancern)

Changes

This PR adds LLVM operand bundle support to MLIR LLVM dialect. It affects these 3 operations related to making function calls: llvm.call, llvm.invoke, and llvm.call_intrinsic.

This PR adds two new parameters to each of the 3 operations. The first parameter is a variadic operand bundle_operands that contains the SSA values for operand bundles. The second parameter is a #llvm.opbundles attribute, which is basically a list of #llvm.opbundle attribute. A #llvm.opbundle attribute provides information about a single operand bundle. It includes a string tag, and a list of integers which index into the bundle_operands parameter to indicate the SSA values included in the operand bundle.


Patch is 37.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108933.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+29)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+19-6)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+4)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+10-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+157-29)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (+50-12)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+47-5)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+47)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2da45eba77655b..67d43e4d2e0657 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1210,4 +1210,33 @@ def WorkgroupAttributionAttr
   let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// OperandBundleAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_OperandBundleAttr : LLVM_Attr<"OperandBundle", "opbundle"> {
+  let summary = "Operand bundle information";
+  let description = [{
+    Provide information about a single operand bundle. Each operand bundle has a
+    string tag together with various number of SSA value uses. The SSA values
+    are specified through indices into the operation's operand bundle operands.
+  }];
+
+  let parameters = (ins "StringAttr":$tag,
+                        OptionalArrayRefParameter<"uint32_t">:$argIndices);
+  let assemblyFormat = [{
+    `<` $tag (`,` $argIndices^)? `>`
+  }];
+}
+
+def LLVM_OperandBundlesAttr : LLVM_Attr<"OperandBundles", "opbundles"> {
+  let summary = "A list of operand bundle attributes";
+  let description = "A list of operand bundle attributes";
+
+  let parameters = (ins ArrayRefParameter<"OperandBundleAttr">:$bundles);
+  let assemblyFormat = [{
+    `<` $bundles `>`
+  }];
+}
+
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d956d7f27f784d..b10a7bbad8eea2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -550,8 +550,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
+                   Variadic<LLVM_Type>:$bundle_operands,
                    OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
-                   DefaultValuedAttr<CConv, "CConv::C">:$CConv);
+                   DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+                   OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
   let results = (outs Optional<LLVM_Type>:$result);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -587,7 +589,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 //===----------------------------------------------------------------------===//
 
 def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
-                    [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+                    [AttrSizedOperandSegments,
+                     DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                      DeclareOpInterfaceMethods<CallOpInterface>,
                      DeclareOpInterfaceMethods<SymbolUserOpInterface>,
                      DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -633,6 +636,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
   dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
+                  Variadic<LLVM_Type>:$bundle_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
@@ -641,7 +645,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
                   OptionalAttr<UnitAttr>:$convergent,
                   OptionalAttr<UnitAttr>:$no_unwind,
-                  OptionalAttr<UnitAttr>:$will_return
+                  OptionalAttr<UnitAttr>:$will_return,
+                  OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles
                   );
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
@@ -662,6 +667,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
+  let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Returns the callee function type.
@@ -1875,21 +1881,28 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
 
 def LLVM_CallIntrinsicOp
     : LLVM_Op<"call_intrinsic",
-              [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
+              [AttrSizedOperandSegments,
+               DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let summary = "Call to an LLVM intrinsic function.";
   let description = [{
     Call the specified llvm intrinsic. If the intrinsic is overloaded, use
     the MLIR function type of this op to determine which intrinsic to call.
   }];
   let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
+                       Variadic<LLVM_Type>:$bundle_operands,
                        DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                         "{}">:$fastmathFlags);
+                                         "{}">:$fastmathFlags,
+                       OptionalAttr<LLVM_OperandBundlesAttr>:$op_bundles);
   let results = (outs Optional<LLVM_Type>:$results);
   let llvmBuilder = [{
     return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
   }];
   let assemblyFormat = [{
-    $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+    $intrin `(` $args `)`
+    ( `bundlearg` `(` $bundle_operands^ `)` )?
+    `:` functional-type($args, $results)
+    ( `,` `tuple` `<` type($bundle_operands)^ `>` )?
+    attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4c2e8682285c52..85ec6031dfee70 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
         promoted, callOp->getAttrs());
 
+    newOp->setAttr(newOp.getOperandSegmentSizesAttrName(),
+                   rewriter.getDenseI32ArrayAttr(
+                       {static_cast<int32_t>(promoted.size()), 0}));
+
     SmallVector<Value, 4> results;
     if (numResults < 2) {
       // If < 2 results, packing did not do anything and we can just return.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ca786316324198..6bb8203da1898a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -837,8 +837,12 @@ class FunctionCallPattern
   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (callOp.getNumResults() == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+      auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
           callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
+      newOp->setAttr(
+          newOp.getOperandSegmentSizesAttrName(),
+          rewriter.getDenseI32ArrayAttr(
+              {static_cast<int32_t>(adaptor.getOperands().size()), 0}));
       return success();
     }
 
@@ -846,8 +850,12 @@ class FunctionCallPattern
     auto dstType = typeConverter.convertType(callOp.getType(0));
     if (!dstType)
       return rewriter.notifyMatchFailure(callOp, "type conversion failed");
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
+    newOp->setAttr(
+        newOp.getOperandSegmentSizesAttrName(),
+        rewriter.getDenseI32ArrayAttr(
+            {static_cast<int32_t>(adaptor.getOperands().size()), 0}));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 205d7494d4378c..bb7df718ccad61 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -949,12 +949,14 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*var_callee_type=*/nullptr, callee, args, /*bundle_operands=*/{},
+        /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -975,11 +977,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), callee, args,
+        /*bundle_operands=*/{},
         /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr,
-        /*no_unwind=*/nullptr, /*will_return=*/nullptr,
+        /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundles=*/nullptr,
         /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -988,12 +991,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType),
-        /*callee=*/nullptr, args,
+        /*callee=*/nullptr, args, /*bundle_operands=*/{},
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
-        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+        /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
@@ -1001,11 +1004,12 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
+        /*bundle_operands=*/{},
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
-        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
-        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+        /*op_bundles=*/nullptr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
 CallInterfaceCallable CallOp::getCallableForCallee() {
@@ -1027,7 +1031,7 @@ void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 }
 
 Operation::operand_range CallOp::getArgOperands() {
-  return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
 MutableOperandRange CallOp::getArgOperandsMutable() {
@@ -1100,6 +1104,38 @@ LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
   return success();
 }
 
+template <typename OpType>
+static LogicalResult verifyOperandBundleOperands(OpType &op) {
+  ValueRange opBundleOperands = op.getBundleOperands();
+  OperandBundlesAttr opBundles = op.getOpBundlesAttr();
+
+  if (!opBundles) {
+    if (!opBundleOperands.empty())
+      return op.emitError("expected operand bundles attribute");
+    return success();
+  }
+
+  DenseSet<uint32_t> seenOperandIdx;
+  for (OperandBundleAttr bundle : opBundles.getBundles()) {
+    for (uint32_t bundleOperandIdx : bundle.getArgIndices()) {
+      if (bundleOperandIdx >= opBundleOperands.size())
+        return op.emitError("operand bundle argument index ")
+               << bundleOperandIdx << " is out of range";
+      seenOperandIdx.insert(bundleOperandIdx);
+    }
+  }
+
+  for (uint32_t idx = 0; idx < opBundleOperands.size(); ++idx) {
+    if (!seenOperandIdx.contains(idx))
+      return op.emitError("operand bundle argument at index ")
+             << idx << " is not included in any operand bundles";
+  }
+
+  return success();
+}
+
+LogicalResult CallOp::verify() { return verifyOperandBundleOperands(*this); }
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(verifyCallOpVarCalleeType(*this)))
     return failure();
@@ -1150,15 +1186,15 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // Verify that the operand and result types match the callee.
 
   if (!funcType.isVarArg() &&
-      funcType.getNumParams() != (getNumOperands() - isIndirect))
+      funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
     return emitOpError() << "incorrect number of operands ("
-                         << (getNumOperands() - isIndirect)
+                         << (getCalleeOperands().size() - isIndirect)
                          << ") for callee (expecting: "
                          << funcType.getNumParams() << ")";
 
-  if (funcType.getNumParams() > (getNumOperands() - isIndirect))
+  if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
     return emitOpError() << "incorrect number of operands ("
-                         << (getNumOperands() - isIndirect)
+                         << (getCalleeOperands().size() - isIndirect)
                          << ") for varargs callee (expecting at least: "
                          << funcType.getNumParams() << ")";
 
@@ -1208,16 +1244,24 @@ void CallOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
-  auto args = getOperands().drop_front(isDirect ? 0 : 1);
+  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
   // Print the variadic callee type if the call is variadic.
   if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
     p << " vararg(" << *varCalleeType << ")";
 
+  // Print the operand bundles, if any.
+  if (!getBundleOperands().empty()) {
+    p << " bundlearg(";
+    p.printOperands(getBundleOperands());
+    p << ")";
+  }
+
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                           {getCalleeAttrName(), getTailCallKindAttrName(),
-                           getVarCalleeTypeAttrName(), getCConvAttrName()});
+                           getVarCalleeTypeAttrName(), getCConvAttrName(),
+                           getOperandSegmentSizesAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1225,24 +1269,53 @@ void CallOp::print(OpAsmPrinter &p) {
 
   // Reconstruct the function MLIR function type from operand and result types.
   p.printFunctionalType(args.getTypes(), getResultTypes());
+
+  if (!getBundleOperands().empty()) {
+    SmallVector<Type> opBundleArgTypes;
+    opBundleArgTypes.reserve(getBundleOperands().size());
+    for (auto opBundleArg : getBundleOperands())
+      opBundleArgTypes.push_back(opBundleArg.getType());
+
+    p << ", tuple<";
+    llvm::interleaveComma(opBundleArgTypes, p);
+    p << ">";
+  }
 }
 
 /// Parses the type of a call operation and resolves the operands if the parsing
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands,
+    ArrayRef<OpAsmParser::UnresolvedOperand> opBundleOperands) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
   if (parser.parseColonTypeList(types))
     return failure();
 
-  if (isDirect && types.size() != 1)
-    return parser.emitError(trailingTypesLoc,
-                            "expected direct call to have 1 trailing type");
-  if (!isDirect && types.size() != 2)
+  if (isDirect && opBundleOperands.empty() && types.size() != 1)
+    return parser.emitError(
+        trailingTypesLoc,
+        "expected direct call without operand bundles to have 1 trailing type");
+  if (isDirect && !opBundleOperands.empty() && types.size() != 2)
+    return parser.emitError(
+        trailingTypesLoc,
+        "expected direct call with operand bundles to have 2 trailing types");
+  if (!isDirect && opBundleOperands.empty() && types.size() != 2)
     return parser.emitError(trailingTypesLoc,
-                            "expected indirect call to have 2 trailing types");
+                            "expected indirect call without operand bundles to "
+                            "have 2 trailing types");
+  if (!isDirect && !opBundleOperands.empty() && types.size() != 3)
+    return parser.emitError(
+        trailingTypesLoc,
+        "expected indirect call with operand bundles to have 3 trailing types");
+
+  TupleType opBundleTypes;
+  if (!opBundleOperands.empty()) {
+    opBundleTypes = llvm::dyn_cast<TupleType>(types.pop_back_val());
+    if (!opBundleTypes)
+      return parser.emitError(trailingTypesLoc, "expected trailing tuple type");
+  }
 
   auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
   if (!funcType)
@@ -1267,6 +1340,12 @@ static ParseResult parseCallTypeAndResolveOperands(
   if (funcType.getNumResults() != 0)
     result.addTypes(funcType.getResults());
 
+  if (!opBundleOperands.empty()) {
+    if (parser.resolveOperands(opBundleOperands, opBundleTypes.getTypes(),
+                               parser.getNameLoc(), result.operands))
+      return failure();
+  }
+
   return success();
 }
 
@@ -1288,7 +1367,9 @@ static ParseResult parseOptionalCallFuncPtr(
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
 //                             ( `vararg(` var-callee-type `)` )?
+//                             ( `bundlearg(` ssa-use-list `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
+//                             (`,` `tuple` `<` type-list `>`)?
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
   TypeAttr varCalleeType;
@@ -1333,11 +1414,25 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
+  SmallVector<OpAsmParser::UnresolvedOperand> opBundleOperands;
+  bool hasOpBundles = parser.parseOptionalKeyword("bundlearg").succeeded();
+  if (hasOpBundles &&
+      parser.parseOperandList(opBundleOperands, OpAsmParser::Delimiter::Paren))
+    return failure();
+
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      opBundleOperands))
+    return failure();
+
+  result.addAttribute(CallOp::getOperandSegmentSizeAttr(),
+                      parser.getBuilder().getDenseI32ArrayAttr(
+                          {static_cast<int32_t>(operands.size()),
+                           static_cast<int32_t>(opBundleOperands.size())}));
+  retu...
[truncated]

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we cannot use VariadicOfVariadic for bundle_operands?
op_bundles would then need to be a DenseI32ArrayAttr, but ought to be kept up to sync automatically by all APIs I think.

@Lancern
Copy link
Member Author

Lancern commented Sep 17, 2024

Is there a reason we cannot use VariadicOfVariadic for bundle_operands?
op_bundles would then need to be a DenseI32ArrayAttr, but ought to be kept up to sync automatically by all APIs I think.

Each operand bundle has a string tag associated with it, thus I have to invent the #llvm.opbundle stuff to store this tag. Do you suggest that the string tags should be kept separately?

@zero9178
Copy link
Member

Do you suggest that the string tags should be kept separately?

I believe this should result in less code and things that could go out of sync. Keeping the number of operands in the current Variadic in sync with the indices used in the operand bundles rather difficult and the implementation detail of the ranges being denoted via indices even leaks into the assembly format.

With the VariadicOfVariadic the only thing that needs to be kept in sync is the number of variadic ranges and the number of corresponding string tags. We could then write a custom printer and parser that prints each operand range together with its string tag.

@Lancern
Copy link
Member Author

Lancern commented Sep 17, 2024

@zero9178 I'll update to follow your proposed approach.

@Dinistro Dinistro self-requested a review September 17, 2024 14:27
@Lancern Lancern force-pushed the mlir-llvmir-opbundle branch from bbb2bc2 to 7a7bc80 Compare September 19, 2024 13:52
@Lancern
Copy link
Member Author

Lancern commented Sep 19, 2024

@zero9178 I've updated to use VariadicOfVariadic for op_bundle_operands.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much! This is pretty much as I imagined conceptually.

I still have one more high-level design question about the syntax before I will do a more thorough pass over the actual code, though I did leave some comments in the code where I noticed issues.

@Lancern Lancern force-pushed the mlir-llvmir-opbundle branch from 7a7bc80 to 7a41798 Compare September 23, 2024 15:20
@Lancern
Copy link
Member Author

Lancern commented Sep 23, 2024

@zero9178 Hi I've updated the PR to use the proposed syntax.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect, thank you so much! I added some nits regarding code-style and found oen case that I believe is a bug. Otherwise I think this will soon be good to go

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this!

I added some mostly minor comments. It seems like some verifier tests and probably also some roundtrip.mlir tests for printing and parsing could make sense?

If you have time to add support for the llvm.intr.assume that would be great.

I think mid term we also want to support operand bundles when importing from LLVM IR into MLIR LLVM dialect (ModuleImport.cpp). However, given the size of the PR it may make sense to do that in a later stage.

llvm.func @call_intrin_with_opbundle(%arg0 : !llvm.ptr) {
%0 = llvm.mlir.constant(1 : i1) : i1
%1 = llvm.mlir.constant(16 : i32) : i32
llvm.call_intrinsic "llvm.assume"(%0) ["align"(%arg0, %1 : !llvm.ptr, i32)] : (i1) -> ()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also an "llvm.intr.assume" intrinsic that would benefit from operand bundle support. It is defined in LLVMIntrinsicOps.td. I wonder if this could be updated as well? AFAIK most intrinsics are nowadays represented using specific operations rather than with the generic llvm.call_intrinsic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make another PR for llvm.intr.assume after this PR lands. After all this PR is about operand bundles :)

@Lancern
Copy link
Member Author

Lancern commented Sep 24, 2024

@zero9178 @gysit I've updated the code again to resolve the nits and add more tests.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is worth having tests in roundtrip.mlir since the printing and parsing is non-trivial here?

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once @gysit's comments are addressed and he is happy 🙂

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the tests! LGTM

@Lancern Lancern force-pushed the mlir-llvmir-opbundle branch from 7a54369 to 4351398 Compare September 24, 2024 14:05
@Lancern
Copy link
Member Author

Lancern commented Sep 24, 2024

Rebased onto the latest main.

@Lancern
Copy link
Member Author

Lancern commented Sep 24, 2024

BTW I don't have commit access so if all reviewers are happy about this PR someone needs to land it.

@Dinistro
Copy link
Contributor

There seems to be a flang issue caused by producing an invalid llvm.call operation.

@gysit
Copy link
Contributor

gysit commented Sep 25, 2024

Once CI is green one of us will happily click the button.

As this is not your first contribution - and since we hope for many more :) - you can consider to apply for commit access (https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access) once this landed.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Sep 25, 2024
@Lancern
Copy link
Member Author

Lancern commented Sep 25, 2024

There seems to be a flang issue caused by producing an invalid llvm.call operation.

Fixed these regressions in my local environment. Let's see if the CI passes.

As this is not your first contribution - and since we hope for many more :) - you can consider to apply for commit access (https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access) once this landed.

Great, I'll apply for the commit access after this PR lands!

@Lancern
Copy link
Member Author

Lancern commented Sep 26, 2024

The CI is green now.

@gysit gysit merged commit fde3c16 into llvm:main Sep 26, 2024
11 checks passed
@Lancern Lancern deleted the mlir-llvmir-opbundle branch September 26, 2024 06:10
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
This PR adds LLVM [operand
bundle](https://llvm.org/docs/LangRef.html#operand-bundles) support to
MLIR LLVM dialect. It affects these 3 operations related to making
function calls: `llvm.call`, `llvm.invoke`, and `llvm.call_intrinsic`.

This PR adds two new parameters to each of the 3 operations. The first
parameter is a variadic operand `op_bundle_operands` that contains the
SSA values for operand bundles. The second parameter is a property
`op_bundle_tags` which holds an array of strings that represent the tags
of each operand bundle.
@antiagainst
Copy link
Member

Are the operand bundle a very core part of the llvm call ops? IIUC it's just auxiliary data? I suppose most of the time a user won't need to specify anything there. For such cases it would be nice to have some default builder or helper functions to make it transparent for downstream users. Right now it's kinda forcing all downstream users to add their wrappers to just ignore these newly added fields, e.g., triton-lang/triton#4847.

@Lancern
Copy link
Member Author

Lancern commented Oct 4, 2024

@antiagainst Hi Zhang, thanks for your feedback!

llvm.call and llvm.invoke already have some convenient builders. For llvm.call:

let builders = [
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
OpBuilder<(ins "LLVMFunctionType":$calleeType, "ValueRange":$args)>,
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>
];

For llvm.invoke:

let builders = [
OpBuilder<(ins "LLVMFuncOp":$func,
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
"Block*":$unwind, "ValueRange":$unwindOps)>,
OpBuilder<(ins "TypeRange":$tys, "FlatSymbolRefAttr":$callee,
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
"Block*":$unwind, "ValueRange":$unwindOps)>,
OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
"ValueRange":$ops, "Block*":$normal, "ValueRange":$normalOps,
"Block*":$unwind, "ValueRange":$unwindOps)>];

All of the above custom builders set the relevant attributes to their default values and you don't have to specify the operand bundle related attributes by hand if you use these builders. We could even add more builders if these builders do not cover your use cases.

I noticed that your code uses this builder overload:

static void build(mlir::OpBuilder &, mlir::OperationState &odsState, mlir::TypeRange resultTypes, mlir::ValueRange operands, llvm::ArrayRef<mlir::NamedAttribute> attributes = {});

Implementation of this builder overload is generated automatically by MLIR and it lets you build the operation with every detail by hand. If you're using this overload you have to specify the operand bundle related attributes manually. I'm afraid we couldn't do much about this since this builder is generated by MLIR.

However llvm.call_intrinsics indeed does not have a custom builder that avoids the burden of specifying operand bundle attributes. We definitely could improve this!

antiagainst added a commit to triton-lang/triton that referenced this pull request Oct 4, 2024
This updates LLVM to pull in two fixes we need for AMD:

* llvm/llvm-project#110553
* llvm/llvm-project#104743

Fixed `LLVM::CallOp` and `LLVM::CallIntrinsicOp` builder API after
* llvm/llvm-project#108933
FMarno added a commit that referenced this pull request Oct 10, 2024
Extra builders for CallIntrinsicOp.
This is inspired by the comment from @antiagainst from
[here](#108933 (comment)).
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
Extra builders for CallIntrinsicOp.
This is inspired by the comment from @antiagainst from
[here](llvm#108933 (comment)).
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
This updates LLVM to pull in two fixes we need for AMD:

* llvm/llvm-project#110553
* llvm/llvm-project#104743

Fixed `LLVM::CallOp` and `LLVM::CallIntrinsicOp` builder API after
* llvm/llvm-project#108933
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
This updates LLVM to pull in two fixes we need for AMD:

* llvm/llvm-project#110553
* llvm/llvm-project#104743

Fixed `LLVM::CallOp` and `LLVM::CallIntrinsicOp` builder API after
* llvm/llvm-project#108933
@nicolasvasilache
Copy link
Contributor

@Lancern why are these additions not marked optional in the op definition?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants