Skip to content

Commit fde3c16

Browse files
authored
[mlir][LLVM] Add operand bundle support (#108933)
This PR adds LLVM [operand bundle](https://llvm.org/docs/LangRef.html#operand-bundles) support to MLIR LLVM dialect. It affects these 3 operations related to making function calls: `llvm.call`, `llvm.invoke`, and `llvm.call_intrinsic`. This PR adds two new parameters to each of the 3 operations. The first parameter is a variadic operand `op_bundle_operands` that contains the SSA values for operand bundles. The second parameter is a property `op_bundle_tags` which holds an array of strings that represent the tags of each operand bundle.
1 parent 23487be commit fde3c16

File tree

9 files changed

+541
-53
lines changed

9 files changed

+541
-53
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,26 @@ static unsigned getLenParamFieldId(mlir::Type ty) {
110110
return getTypeDescFieldId(ty) + 1;
111111
}
112112

113+
static llvm::SmallVector<mlir::NamedAttribute>
114+
addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
115+
llvm::ArrayRef<mlir::NamedAttribute> attrs,
116+
int32_t numCallOperands) {
117+
llvm::SmallVector<mlir::NamedAttribute> newAttrs;
118+
newAttrs.reserve(attrs.size() + 2);
119+
120+
for (mlir::NamedAttribute attr : attrs) {
121+
if (attr.getName() != "operandSegmentSizes")
122+
newAttrs.push_back(attr);
123+
}
124+
125+
newAttrs.push_back(rewriter.getNamedAttr(
126+
"operandSegmentSizes",
127+
rewriter.getDenseI32ArrayAttr({numCallOperands, 0})));
128+
newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes",
129+
rewriter.getDenseI32ArrayAttr({})));
130+
return newAttrs;
131+
}
132+
113133
namespace {
114134
/// Lower `fir.address_of` operation to `llvm.address_of` operation.
115135
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
@@ -229,7 +249,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
229249
mlir::NamedAttribute attr = rewriter.getNamedAttr(
230250
"callee", mlir::SymbolRefAttr::get(memSizeFn));
231251
auto call = rewriter.create<mlir::LLVM::CallOp>(
232-
loc, ity, lenParams, llvm::ArrayRef<mlir::NamedAttribute>{attr});
252+
loc, ity, lenParams,
253+
addLLVMOpBundleAttrs(rewriter, {attr}, lenParams.size()));
233254
size = call.getResult();
234255
llvmObjectType = ::getI8Type(alloc.getContext());
235256
} else {
@@ -559,7 +580,9 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
559580
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
560581
attrConvert(call);
561582
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
562-
call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
583+
call, resultTys, adaptor.getOperands(),
584+
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
585+
adaptor.getOperands().size()));
563586
return mlir::success();
564587
}
565588
};
@@ -980,7 +1003,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
9801003
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
9811004
heap->setAttr("callee", getMalloc(heap, rewriter));
9821005
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
983-
heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
1006+
heap, ::getLlvmPtrType(heap.getContext()), size,
1007+
addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 1));
9841008
return mlir::success();
9851009
}
9861010

@@ -1037,9 +1061,9 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
10371061
mlir::ConversionPatternRewriter &rewriter) const override {
10381062
mlir::Location loc = freemem.getLoc();
10391063
freemem->setAttr("callee", getFree(freemem, rewriter));
1040-
rewriter.create<mlir::LLVM::CallOp>(loc, mlir::TypeRange{},
1041-
mlir::ValueRange{adaptor.getHeapref()},
1042-
freemem->getAttrs());
1064+
rewriter.create<mlir::LLVM::CallOp>(
1065+
loc, mlir::TypeRange{}, mlir::ValueRange{adaptor.getHeapref()},
1066+
addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 1));
10431067
rewriter.eraseOp(freemem);
10441068
return mlir::success();
10451069
}
@@ -2671,7 +2695,8 @@ struct FieldIndexOpConversion : public fir::FIROpConversion<fir::FieldIndexOp> {
26712695
"field", mlir::IntegerAttr::get(lowerTy().indexType(), index));
26722696
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
26732697
field, lowerTy().offsetType(), adaptor.getOperands(),
2674-
llvm::ArrayRef<mlir::NamedAttribute>{callAttr, fieldAttr});
2698+
addLLVMOpBundleAttrs(rewriter, {callAttr, fieldAttr},
2699+
adaptor.getOperands().size()));
26752700
return mlir::success();
26762701
}
26772702

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
551551
Variadic<LLVM_Type>:$normalDestOperands,
552552
Variadic<LLVM_Type>:$unwindDestOperands,
553553
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
554-
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
554+
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
555+
VariadicOfVariadic<LLVM_Type,
556+
"op_bundle_sizes">:$op_bundle_operands,
557+
DenseI32ArrayAttr:$op_bundle_sizes,
558+
DefaultValuedProperty<
559+
ArrayProperty<StringProperty, "operand bundle tags">,
560+
"ArrayRef<std::string>{}",
561+
"SmallVector<std::string>{}"
562+
>:$op_bundle_tags);
555563
let results = (outs Optional<LLVM_Type>:$result);
556564
let successors = (successor AnySuccessor:$normalDest,
557565
AnySuccessor:$unwindDest);
@@ -607,7 +615,8 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
607615
//===----------------------------------------------------------------------===//
608616

609617
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
610-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
618+
[AttrSizedOperandSegments,
619+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
611620
DeclareOpInterfaceMethods<CallOpInterface>,
612621
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
613622
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -661,8 +670,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
661670
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
662671
OptionalAttr<UnitAttr>:$convergent,
663672
OptionalAttr<UnitAttr>:$no_unwind,
664-
OptionalAttr<UnitAttr>:$will_return
665-
);
673+
OptionalAttr<UnitAttr>:$will_return,
674+
VariadicOfVariadic<LLVM_Type,
675+
"op_bundle_sizes">:$op_bundle_operands,
676+
DenseI32ArrayAttr:$op_bundle_sizes,
677+
DefaultValuedProperty<
678+
ArrayProperty<StringProperty, "operand bundle tags">,
679+
"ArrayRef<std::string>{}",
680+
"SmallVector<std::string>{}"
681+
>:$op_bundle_tags);
666682
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
667683
let arguments = !con(args, aliasAttrs);
668684
let results = (outs Optional<LLVM_Type>:$result);
@@ -682,6 +698,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
682698
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
683699
CArg<"ValueRange", "{}">:$args)>
684700
];
701+
let hasVerifier = 1;
685702
let hasCustomAssemblyFormat = 1;
686703
let extraClassDeclaration = [{
687704
/// Returns the callee function type.
@@ -1895,21 +1912,34 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
18951912

18961913
def LLVM_CallIntrinsicOp
18971914
: LLVM_Op<"call_intrinsic",
1898-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1915+
[AttrSizedOperandSegments,
1916+
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
18991917
let summary = "Call to an LLVM intrinsic function.";
19001918
let description = [{
19011919
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
19021920
the MLIR function type of this op to determine which intrinsic to call.
19031921
}];
19041922
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
19051923
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1906-
"{}">:$fastmathFlags);
1924+
"{}">:$fastmathFlags,
1925+
VariadicOfVariadic<LLVM_Type,
1926+
"op_bundle_sizes">:$op_bundle_operands,
1927+
DenseI32ArrayAttr:$op_bundle_sizes,
1928+
DefaultValuedProperty<
1929+
ArrayProperty<StringProperty, "operand bundle tags">,
1930+
"ArrayRef<std::string>{}",
1931+
"SmallVector<std::string>{}"
1932+
>:$op_bundle_tags);
19071933
let results = (outs Optional<LLVM_Type>:$results);
19081934
let llvmBuilder = [{
19091935
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
19101936
}];
19111937
let assemblyFormat = [{
1912-
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
1938+
$intrin `(` $args `)`
1939+
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
1940+
$op_bundle_tags)^ )?
1941+
`:` functional-type($args, $results)
1942+
attr-dict
19131943
}];
19141944

19151945
let hasVerifier = 1;

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
544544
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
545545
promoted, callOp->getAttrs());
546546

547+
newOp.getProperties().operandSegmentSizes = {
548+
static_cast<int32_t>(promoted.size()), 0};
549+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
550+
547551
SmallVector<Value, 4> results;
548552
if (numResults < 2) {
549553
// If < 2 results, packing did not do anything and we can just return.

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,17 +837,23 @@ class FunctionCallPattern
837837
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const override {
839839
if (callOp.getNumResults() == 0) {
840-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
840+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841841
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
842+
newOp.getProperties().operandSegmentSizes = {
843+
static_cast<int32_t>(adaptor.getOperands().size()), 0};
844+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
842845
return success();
843846
}
844847

845848
// Function returns a single result.
846849
auto dstType = typeConverter.convertType(callOp.getType(0));
847850
if (!dstType)
848851
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
849-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
852+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
850853
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
854+
newOp.getProperties().operandSegmentSizes = {
855+
static_cast<int32_t>(adaptor.getOperands().size()), 0};
856+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
851857
return success();
852858
}
853859
};

0 commit comments

Comments
 (0)