Skip to content

Commit 7a7bc80

Browse files
committed
[mlir][LLVM] Add operand bundle support
1 parent 30d7dcc commit 7a7bc80

File tree

7 files changed

+414
-39
lines changed

7 files changed

+414
-39
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,13 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
551551
Variadic<LLVM_Type>:$normalDestOperands,
552552
Variadic<LLVM_Type>:$unwindDestOperands,
553553
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
554-
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
554+
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
555+
VariadicOfVariadic<LLVM_Type,
556+
"op_bundle_sizes">:$op_bundle_operands,
557+
DenseI32ArrayAttr:$op_bundle_sizes,
558+
OptionalProperty<
559+
ArrayProperty<StringProperty, "operand bundle tags">
560+
>:$op_bundle_tags);
555561
let results = (outs Optional<LLVM_Type>:$result);
556562
let successors = (successor AnySuccessor:$normalDest,
557563
AnySuccessor:$unwindDest);
@@ -587,7 +593,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
587593
//===----------------------------------------------------------------------===//
588594

589595
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
590-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
596+
[AttrSizedOperandSegments,
597+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
591598
DeclareOpInterfaceMethods<CallOpInterface>,
592599
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
593600
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -641,8 +648,13 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
641648
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
642649
OptionalAttr<UnitAttr>:$convergent,
643650
OptionalAttr<UnitAttr>:$no_unwind,
644-
OptionalAttr<UnitAttr>:$will_return
645-
);
651+
OptionalAttr<UnitAttr>:$will_return,
652+
VariadicOfVariadic<LLVM_Type,
653+
"op_bundle_sizes">:$op_bundle_operands,
654+
DenseI32ArrayAttr:$op_bundle_sizes,
655+
OptionalProperty<
656+
ArrayProperty<StringProperty, "operand bundle tags">
657+
>:$op_bundle_tags);
646658
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
647659
let arguments = !con(args, aliasAttrs);
648660
let results = (outs Optional<LLVM_Type>:$result);
@@ -662,6 +674,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
662674
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
663675
CArg<"ValueRange", "{}">:$args)>
664676
];
677+
let hasVerifier = 1;
665678
let hasCustomAssemblyFormat = 1;
666679
let extraClassDeclaration = [{
667680
/// Returns the callee function type.
@@ -1875,21 +1888,33 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
18751888

18761889
def LLVM_CallIntrinsicOp
18771890
: LLVM_Op<"call_intrinsic",
1878-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1891+
[AttrSizedOperandSegments,
1892+
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
18791893
let summary = "Call to an LLVM intrinsic function.";
18801894
let description = [{
18811895
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
18821896
the MLIR function type of this op to determine which intrinsic to call.
18831897
}];
18841898
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
18851899
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1886-
"{}">:$fastmathFlags);
1900+
"{}">:$fastmathFlags,
1901+
VariadicOfVariadic<LLVM_Type,
1902+
"op_bundle_sizes">:$op_bundle_operands,
1903+
DenseI32ArrayAttr:$op_bundle_sizes,
1904+
OptionalProperty<
1905+
ArrayProperty<StringProperty, "operand bundle tags">
1906+
>:$op_bundle_tags);
18871907
let results = (outs Optional<LLVM_Type>:$results);
18881908
let llvmBuilder = [{
18891909
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
18901910
}];
18911911
let assemblyFormat = [{
1892-
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
1912+
$intrin `(` $args `)`
1913+
( `bundlearg` `(` $op_bundle_operands^ `)` )?
1914+
( `bundletag` `(` $op_bundle_tags^ `)` )?
1915+
`:` functional-type($args, $results)
1916+
( `bundletype` `(` type($op_bundle_operands)^ `)` )?
1917+
attr-dict
18931918
}];
18941919

18951920
let hasVerifier = 1;

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,12 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
544544
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
545545
promoted, callOp->getAttrs());
546546

547+
newOp->setAttr(newOp.getOperandSegmentSizesAttrName(),
548+
rewriter.getDenseI32ArrayAttr(
549+
{static_cast<int32_t>(promoted.size()), 0}));
550+
newOp->setAttr(newOp.getOpBundleSizesAttrName(),
551+
rewriter.getDenseI32ArrayAttr({}));
552+
547553
SmallVector<Value, 4> results;
548554
if (numResults < 2) {
549555
// If < 2 results, packing did not do anything and we can just return.

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,17 +837,29 @@ class FunctionCallPattern
837837
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const override {
839839
if (callOp.getNumResults() == 0) {
840-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
840+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841841
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
842+
newOp->setAttr(
843+
newOp.getOperandSegmentSizesAttrName(),
844+
rewriter.getDenseI32ArrayAttr(
845+
{static_cast<int32_t>(adaptor.getOperands().size()), 0}));
846+
newOp->setAttr(newOp.getOpBundleSizesAttrName(),
847+
rewriter.getDenseI32ArrayAttr({}));
842848
return success();
843849
}
844850

845851
// Function returns a single result.
846852
auto dstType = typeConverter.convertType(callOp.getType(0));
847853
if (!dstType)
848854
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
849-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
850856
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
857+
newOp->setAttr(
858+
newOp.getOperandSegmentSizesAttrName(),
859+
rewriter.getDenseI32ArrayAttr(
860+
{static_cast<int32_t>(adaptor.getOperands().size()), 0}));
861+
newOp->setAttr(newOp.getOpBundleSizesAttrName(),
862+
rewriter.getDenseI32ArrayAttr({}));
851863
return success();
852864
}
853865
};

0 commit comments

Comments
 (0)