Skip to content

Commit 7a41798

Browse files
committed
[mlir][LLVM] Add operand bundle support
1 parent 30d7dcc commit 7a41798

File tree

7 files changed

+401
-43
lines changed

7 files changed

+401
-43
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
551551
Variadic<LLVM_Type>:$normalDestOperands,
552552
Variadic<LLVM_Type>:$unwindDestOperands,
553553
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
554-
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
554+
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
555+
VariadicOfVariadic<LLVM_Type,
556+
"op_bundle_sizes">:$op_bundle_operands,
557+
DenseI32ArrayAttr:$op_bundle_sizes,
558+
DefaultValuedProperty<
559+
ArrayProperty<StringProperty, "operand bundle tags">,
560+
"ArrayRef<std::string>{}",
561+
"SmallVector<std::string>{}"
562+
>:$op_bundle_tags);
555563
let results = (outs Optional<LLVM_Type>:$result);
556564
let successors = (successor AnySuccessor:$normalDest,
557565
AnySuccessor:$unwindDest);
@@ -587,7 +595,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
587595
//===----------------------------------------------------------------------===//
588596

589597
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
590-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
598+
[AttrSizedOperandSegments,
599+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
591600
DeclareOpInterfaceMethods<CallOpInterface>,
592601
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
593602
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -641,8 +650,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
641650
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
642651
OptionalAttr<UnitAttr>:$convergent,
643652
OptionalAttr<UnitAttr>:$no_unwind,
644-
OptionalAttr<UnitAttr>:$will_return
645-
);
653+
OptionalAttr<UnitAttr>:$will_return,
654+
VariadicOfVariadic<LLVM_Type,
655+
"op_bundle_sizes">:$op_bundle_operands,
656+
DenseI32ArrayAttr:$op_bundle_sizes,
657+
DefaultValuedProperty<
658+
ArrayProperty<StringProperty, "operand bundle tags">,
659+
"ArrayRef<std::string>{}",
660+
"SmallVector<std::string>{}"
661+
>:$op_bundle_tags);
646662
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
647663
let arguments = !con(args, aliasAttrs);
648664
let results = (outs Optional<LLVM_Type>:$result);
@@ -662,6 +678,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
662678
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
663679
CArg<"ValueRange", "{}">:$args)>
664680
];
681+
let hasVerifier = 1;
665682
let hasCustomAssemblyFormat = 1;
666683
let extraClassDeclaration = [{
667684
/// Returns the callee function type.
@@ -1875,21 +1892,34 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
18751892

18761893
def LLVM_CallIntrinsicOp
18771894
: LLVM_Op<"call_intrinsic",
1878-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1895+
[AttrSizedOperandSegments,
1896+
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
18791897
let summary = "Call to an LLVM intrinsic function.";
18801898
let description = [{
18811899
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
18821900
the MLIR function type of this op to determine which intrinsic to call.
18831901
}];
18841902
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
18851903
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1886-
"{}">:$fastmathFlags);
1904+
"{}">:$fastmathFlags,
1905+
VariadicOfVariadic<LLVM_Type,
1906+
"op_bundle_sizes">:$op_bundle_operands,
1907+
DenseI32ArrayAttr:$op_bundle_sizes,
1908+
DefaultValuedProperty<
1909+
ArrayProperty<StringProperty, "operand bundle tags">,
1910+
"ArrayRef<std::string>{}",
1911+
"SmallVector<std::string>{}"
1912+
>:$op_bundle_tags);
18871913
let results = (outs Optional<LLVM_Type>:$results);
18881914
let llvmBuilder = [{
18891915
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
18901916
}];
18911917
let assemblyFormat = [{
1892-
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
1918+
$intrin `(` $args `)`
1919+
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
1920+
$op_bundle_tags)^ )?
1921+
`:` functional-type($args, $results)
1922+
attr-dict
18931923
}];
18941924

18951925
let hasVerifier = 1;

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
544544
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
545545
promoted, callOp->getAttrs());
546546

547+
newOp.getProperties().operandSegmentSizes = {
548+
static_cast<int32_t>(promoted.size()), 0};
549+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
550+
547551
SmallVector<Value, 4> results;
548552
if (numResults < 2) {
549553
// If < 2 results, packing did not do anything and we can just return.

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,17 +837,23 @@ class FunctionCallPattern
837837
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const override {
839839
if (callOp.getNumResults() == 0) {
840-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
840+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841841
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
842+
newOp.getProperties().operandSegmentSizes = {
843+
static_cast<int32_t>(adaptor.getOperands().size()), 0};
844+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
842845
return success();
843846
}
844847

845848
// Function returns a single result.
846849
auto dstType = typeConverter.convertType(callOp.getType(0));
847850
if (!dstType)
848851
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
849-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
852+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
850853
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
854+
newOp.getProperties().operandSegmentSizes = {
855+
static_cast<int32_t>(adaptor.getOperands().size()), 0};
856+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
851857
return success();
852858
}
853859
};

0 commit comments

Comments
 (0)