Skip to content

Commit 6a8ba31

Browse files
committed
[mlir] Split std.splat into tensor.splat and vector.splat
This is part of the larger effort to split the standard dialect. This will also allow for pruning some additional dependencies on Standard (done in a followup). Differential Revision: https://reviews.llvm.org/D118202
1 parent f7a6c34 commit 6a8ba31

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+504
-419
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -507,55 +507,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
507507
let hasVerifier = 1;
508508
}
509509

510-
//===----------------------------------------------------------------------===//
511-
// SplatOp
512-
//===----------------------------------------------------------------------===//
513-
514-
def SplatOp : Std_Op<"splat", [NoSideEffect,
515-
TypesMatchWith<"operand type matches element type of result",
516-
"aggregate", "input",
517-
"$_self.cast<ShapedType>().getElementType()">]> {
518-
let summary = "splat or broadcast operation";
519-
let description = [{
520-
Broadcast the operand to all elements of the result vector or tensor. The
521-
operand has to be of integer/index/float type. When the result is a tensor,
522-
it has to be statically shaped.
523-
524-
Example:
525-
526-
```mlir
527-
%s = load %A[%i] : memref<128xf32>
528-
%v = splat %s : vector<4xf32>
529-
%t = splat %s : tensor<8x16xi32>
530-
```
531-
532-
TODO: This operation is easy to extend to broadcast to dynamically shaped
533-
tensors in the same way dynamically shaped memrefs are handled.
534-
535-
```mlir
536-
// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
537-
// to the sizes of the two dynamic dimensions.
538-
%m = "foo"() : () -> (index)
539-
%n = "bar"() : () -> (index)
540-
%t = splat %s [%m, %n] : tensor<?x?xi32>
541-
```
542-
}];
543-
544-
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
545-
"integer/index/float type">:$input);
546-
let results = (outs AnyTypeOf<[AnyVectorOfAnyRank,
547-
AnyStaticShapeTensor]>:$aggregate);
548-
549-
let builders = [
550-
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
551-
[{ build($_builder, $_state, aggregateType, element); }]>];
552-
553-
let hasFolder = 1;
554-
let hasVerifier = 1;
555-
556-
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
557-
}
558-
559510
//===----------------------------------------------------------------------===//
560511
// SwitchOp
561512
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,52 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
968968
let hasVerifier = 1;
969969
}
970970

971+
//===----------------------------------------------------------------------===//
972+
// SplatOp
973+
//===----------------------------------------------------------------------===//
974+
975+
def Tensor_SplatOp : Tensor_Op<"splat", [
976+
NoSideEffect,
977+
TypesMatchWith<"operand type matches element type of result",
978+
"aggregate", "input",
979+
"$_self.cast<TensorType>().getElementType()">
980+
]> {
981+
let summary = "tensor splat or broadcast operation";
982+
let description = [{
983+
Broadcast the operand to all elements of the result tensor. The operand is
984+
required to be of integer/index/float type, and the result tensor must be
985+
statically shaped.
986+
987+
Example:
988+
989+
```mlir
990+
%s = arith.constant 10.1 : f32
991+
%t = tensor.splat %s : tensor<8x16xi32>
992+
```
993+
994+
TODO: This operation is easy to extend to broadcast to dynamically shaped
995+
tensors:
996+
997+
```mlir
998+
// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
999+
// to the sizes of the two dynamic dimensions.
1000+
%m = "foo"() : () -> (index)
1001+
%n = "bar"() : () -> (index)
1002+
%t = tensor.splat %s [%m, %n] : tensor<?x?xi32>
1003+
```
1004+
}];
1005+
1006+
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
1007+
"integer/index/float type">:$input);
1008+
let results = (outs AnyStaticShapeTensor:$aggregate);
1009+
1010+
let builders = [
1011+
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
1012+
[{ build($_builder, $_state, aggregateType, element); }]>];
1013+
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
1014+
1015+
let hasFolder = 1;
1016+
}
9711017

9721018
//===----------------------------------------------------------------------===//
9731019
// YieldOp

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,6 +2420,41 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
24202420
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
24212421
}
24222422

2423+
//===----------------------------------------------------------------------===//
2424+
// SplatOp
2425+
//===----------------------------------------------------------------------===//
2426+
2427+
def Vector_SplatOp : Vector_Op<"splat", [
2428+
NoSideEffect,
2429+
TypesMatchWith<"operand type matches element type of result",
2430+
"aggregate", "input",
2431+
"$_self.cast<VectorType>().getElementType()">
2432+
]> {
2433+
let summary = "vector splat or broadcast operation";
2434+
let description = [{
2435+
Broadcast the operand to all elements of the result vector. The operand is
2436+
required to be of integer/index/float type.
2437+
2438+
Example:
2439+
2440+
```mlir
2441+
%s = arith.constant 10.1 : f32
2442+
%t = vector.splat %s : vector<8x16xi32>
2443+
```
2444+
}];
2445+
2446+
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
2447+
"integer/index/float type">:$input);
2448+
let results = (outs AnyVectorOfAnyRank:$aggregate);
2449+
2450+
let builders = [
2451+
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
2452+
[{ build($_builder, $_state, aggregateType, element); }]>];
2453+
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
2454+
2455+
let hasFolder = 1;
2456+
}
2457+
24232458
//===----------------------------------------------------------------------===//
24242459
// VectorScaleOp
24252460
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/Attributes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class Attribute {
4949
template <typename U> bool isa() const;
5050
template <typename First, typename Second, typename... Rest>
5151
bool isa() const;
52+
template <typename First, typename... Rest>
53+
bool isa_and_nonnull() const;
5254
template <typename U> U dyn_cast() const;
5355
template <typename U> U dyn_cast_or_null() const;
5456
template <typename U> U cast() const;
@@ -114,6 +116,11 @@ bool Attribute::isa() const {
114116
return isa<First>() || isa<Second, Rest...>();
115117
}
116118

119+
template <typename First, typename... Rest>
120+
bool Attribute::isa_and_nonnull() const {
121+
return impl && isa<First, Rest...>();
122+
}
123+
117124
template <typename U> U Attribute::dyn_cast() const {
118125
return isa<U>() ? U(impl) : U(nullptr);
119126
}

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 0 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -663,99 +663,6 @@ struct SwitchOpLowering
663663
using Super::Super;
664664
};
665665

666-
// The Splat operation is lowered to an insertelement + a shufflevector
667-
// operation. Splat to only 0-d and 1-d vector result types are lowered.
668-
struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
669-
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
670-
671-
LogicalResult
672-
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
673-
ConversionPatternRewriter &rewriter) const override {
674-
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
675-
if (!resultType || resultType.getRank() > 1)
676-
return failure();
677-
678-
// First insert it into an undef vector so we can shuffle it.
679-
auto vectorType = typeConverter->convertType(splatOp.getType());
680-
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
681-
auto zero = rewriter.create<LLVM::ConstantOp>(
682-
splatOp.getLoc(),
683-
typeConverter->convertType(rewriter.getIntegerType(32)),
684-
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
685-
686-
// For 0-d vector, we simply do `insertelement`.
687-
if (resultType.getRank() == 0) {
688-
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
689-
splatOp, vectorType, undef, adaptor.getInput(), zero);
690-
return success();
691-
}
692-
693-
// For 1-d vector, we additionally do a `vectorshuffle`.
694-
auto v = rewriter.create<LLVM::InsertElementOp>(
695-
splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
696-
697-
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
698-
SmallVector<int32_t, 4> zeroValues(width, 0);
699-
700-
// Shuffle the value across the desired number of elements.
701-
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
702-
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
703-
zeroAttrs);
704-
return success();
705-
}
706-
};
707-
708-
// The Splat operation is lowered to an insertelement + a shufflevector
709-
// operation. Splat to only 2+-d vector result types are lowered by the
710-
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
711-
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
712-
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
713-
714-
LogicalResult
715-
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
716-
ConversionPatternRewriter &rewriter) const override {
717-
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
718-
if (!resultType || resultType.getRank() <= 1)
719-
return failure();
720-
721-
// First insert it into an undef vector so we can shuffle it.
722-
auto loc = splatOp.getLoc();
723-
auto vectorTypeInfo =
724-
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
725-
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
726-
auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
727-
if (!llvmNDVectorTy || !llvm1DVectorTy)
728-
return failure();
729-
730-
// Construct returned value.
731-
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
732-
733-
// Construct a 1-D vector with the splatted value that we insert in all the
734-
// places within the returned descriptor.
735-
Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
736-
auto zero = rewriter.create<LLVM::ConstantOp>(
737-
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
738-
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
739-
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
740-
adaptor.getInput(), zero);
741-
742-
// Shuffle the value across the desired number of elements.
743-
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
744-
SmallVector<int32_t, 4> zeroValues(width, 0);
745-
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
746-
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
747-
748-
// Iterate of linear index, convert to coords space and insert splatted 1-D
749-
// vector in each position.
750-
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
751-
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
752-
position);
753-
});
754-
rewriter.replaceOp(splatOp, desc);
755-
return success();
756-
}
757-
};
758-
759666
} // namespace
760667

761668
void mlir::populateStdToLLVMFuncOpConversionPattern(
@@ -779,8 +686,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
779686
ConstantOpLowering,
780687
ReturnOpLowering,
781688
SelectOpLowering,
782-
SplatOpLowering,
783-
SplatNdOpLowering,
784689
SwitchOpLowering>(converter);
785690
// clang-format on
786691
}

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,6 @@ class SelectOpPattern final : public OpConversionPattern<SelectOp> {
5555
ConversionPatternRewriter &rewriter) const override;
5656
};
5757

58-
/// Converts std.splat to spv.CompositeConstruct.
59-
class SplatPattern final : public OpConversionPattern<SplatOp> {
60-
public:
61-
using OpConversionPattern<SplatOp>::OpConversionPattern;
62-
63-
LogicalResult
64-
matchAndRewrite(SplatOp op, OpAdaptor adaptor,
65-
ConversionPatternRewriter &rewriter) const override;
66-
};
67-
6858
/// Converts std.br to spv.Branch.
6959
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
7060
using OpConversionPattern<BranchOp>::OpConversionPattern;
@@ -178,22 +168,6 @@ SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor,
178168
return success();
179169
}
180170

181-
//===----------------------------------------------------------------------===//
182-
// SplatOp
183-
//===----------------------------------------------------------------------===//
184-
185-
LogicalResult
186-
SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
187-
ConversionPatternRewriter &rewriter) const {
188-
auto dstVecType = op.getType().dyn_cast<VectorType>();
189-
if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
190-
return failure();
191-
SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.getInput());
192-
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
193-
source);
194-
return success();
195-
}
196-
197171
//===----------------------------------------------------------------------===//
198172
// BranchOpPattern
199173
//===----------------------------------------------------------------------===//
@@ -237,8 +211,8 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
237211
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
238212
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
239213

240-
ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
241-
CondBranchOpPattern>(typeConverter, context);
214+
ReturnOpPattern, SelectOpPattern, BranchOpPattern, CondBranchOpPattern>(
215+
typeConverter, context);
242216
}
243217

244218
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,

0 commit comments

Comments
 (0)