@@ -43,8 +43,9 @@ using namespace mlir::linalg;
43
43
// Debug-stream prefix: yields llvm::dbgs() after writing "[<DEBUG_TYPE>] ".
// The `(` must directly follow the macro name; with a space in between the
// macro becomes object-like and every `DBGS()` use expands incorrectly.
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
44
44
// Streams `X` after the DBGS() prefix, with the whole statement wrapped in
// LLVM_DEBUG. The parameter list must directly follow the macro name so that
// this is a function-like macro taking `X`.
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
45
45
46
- static FailureOr<Operation *>
47
- vectorizeConvolution (OpBuilder &b, ConvolutionOpInterface convOp);
46
+ /// Try to vectorize `convOp` as a convolution.
47
+ static FailureOr<Operation *> vectorizeConvolution (OpBuilder &b,
48
+ LinalgOp convOp);
48
49
49
50
/// Return the unique instance of OpType in `block` if it is indeed unique.
50
51
/// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
636
637
SmallVector<Value> results;
637
638
// TODO: isaConvolutionOpInterface that can also infer from generic
638
639
// features. Will require stride/dilation attributes inference.
639
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation ())) {
640
- LDBG (" Vectorize as a conv: " << linalgOp);
641
- FailureOr<Operation *> convOr = vectorizeConvolution (rewriter, convOp);
642
- if (failed (convOr))
643
- return failure ();
640
+ FailureOr<Operation *> convOr = vectorizeConvolution (rewriter, linalgOp);
641
+ if (succeeded (convOr)) {
644
642
llvm::append_range (results, (*convOr)->getResults ());
645
643
} else {
644
+ if (failed (vectorizeLinalgOpPrecondition (linalgOp)))
645
+ return failure ();
646
646
LDBG (" Vectorize generic by broadcasting to a common shape: " << linalgOp);
647
647
if (failed (vectorizeAsLinalgGeneric (rewriter, linalgOp, results)))
648
648
return failure ();
@@ -1098,134 +1098,6 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
1098
1098
patterns.getContext (), baseBenefit.getBenefit () + 1 );
1099
1099
}
1100
1100
1101
- // TODO: cleanup all the convolution vectorization patterns.
1102
- template <class ConvOp , int N>
1103
- LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
1104
- ConvOp op, PatternRewriter &rewriter) const {
1105
- Location loc = op.getLoc ();
1106
- MLIRContext *context = op.getContext ();
1107
-
1108
- OpOperand *input = op.getInputOperand (0 );
1109
- OpOperand *kernel = op.getInputOperand (1 );
1110
- OpOperand *output = op.getOutputOperand (0 );
1111
- ArrayRef<int64_t > inShape = op.getShape (input);
1112
- ArrayRef<int64_t > kShape = op.getShape (kernel);
1113
-
1114
- if (llvm::any_of (inShape, ShapedType::isDynamic) ||
1115
- llvm::any_of (kShape , ShapedType::isDynamic))
1116
- return failure ();
1117
-
1118
- SmallVector<AffineExpr, 4 > mapping;
1119
- SmallVector<int64_t , 4 > vectorDims;
1120
- // Fail to apply when the size of not vectorized dimension is not 1.
1121
- for (unsigned i = 0 ; i < N; i++) {
1122
- if (!mask[i] && (inShape[i] != 1 || kShape [i] != 1 ))
1123
- return failure ();
1124
-
1125
- if (mask[i] && inShape[i] != kShape [i])
1126
- return failure ();
1127
-
1128
- if (mask[i]) {
1129
- mapping.push_back (getAffineDimExpr (i, context));
1130
- vectorDims.push_back (inShape[i]);
1131
- }
1132
- }
1133
-
1134
- int64_t rank = op.getRank (input);
1135
- int64_t numDims = mapping.size ();
1136
- Type elemType = getElementTypeOrSelf (input->get ());
1137
-
1138
- auto map = AffineMap::get (rank, 0 , mapping, context);
1139
- SmallVector<Value, 4 > zeros (rank,
1140
- rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
1141
- auto vecType = VectorType::get (vectorDims, elemType);
1142
-
1143
- auto inputVec = rewriter.create <vector::TransferReadOp>(
1144
- loc, vecType, input->get (), zeros, map);
1145
- auto kernelVec = rewriter.create <vector::TransferReadOp>(
1146
- loc, vecType, kernel->get (), zeros, map);
1147
-
1148
- auto acc = rewriter.create <arith::ConstantOp>(loc, elemType,
1149
- rewriter.getZeroAttr (elemType));
1150
-
1151
- std::array<AffineMap, 3 > indexingMaps{
1152
- AffineMap::getMultiDimIdentityMap (numDims, context),
1153
- AffineMap::getMultiDimIdentityMap (numDims, context),
1154
- AffineMap::get (numDims, 0 , {}, context)};
1155
-
1156
- std::vector<StringRef> iteratorTypes (numDims, " reduction" );
1157
-
1158
- auto result = rewriter.create <vector::ContractionOp>(
1159
- loc, inputVec, kernelVec, acc,
1160
- rewriter.getAffineMapArrayAttr (indexingMaps),
1161
- rewriter.getStrArrayAttr (iteratorTypes));
1162
-
1163
- rewriter.create <memref::StoreOp>(loc, result, output->get (),
1164
- ValueRange (zeros));
1165
- rewriter.eraseOp (op);
1166
- return success ();
1167
- }
1168
-
1169
- /// Inserts tiling, promotion and vectorization pattern for ConvOp
1170
- /// conversion into corresponding pattern lists.
1171
- template <typename ConvOp, unsigned N>
1172
- static void populateVectorizationPatterns (
1173
- RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
1174
- RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t > tileSizes) {
1175
- auto *context = tilingPatterns.getContext ();
1176
- if (tileSizes.size () < N)
1177
- return ;
1178
-
1179
- constexpr static StringRef kTiledMarker = " TILED" ;
1180
- constexpr static StringRef kPromotedMarker = " PROMOTED" ;
1181
- tilingPatterns.add <LinalgTilingPattern>(
1182
- ConvOp::getOperationName (), context,
1183
- LinalgTilingOptions ().setTileSizes (tileSizes),
1184
- LinalgTransformationFilter (ArrayRef<StringAttr>{},
1185
- StringAttr::get (context, kTiledMarker )));
1186
-
1187
- promotionPatterns.add <LinalgPromotionPattern<ConvOp>>(
1188
- context, LinalgPromotionOptions ().setUseFullTileBuffersByDefault (true ),
1189
- LinalgTransformationFilter (StringAttr::get (context, kTiledMarker ),
1190
- StringAttr::get (context, kPromotedMarker )));
1191
-
1192
- SmallVector<bool , 4 > mask (N);
1193
- int offset = tileSizes.size () - N;
1194
- std::transform (tileSizes.begin () + offset, tileSizes.end (), mask.begin (),
1195
- [](int64_t i) -> bool { return i > 1 ; });
1196
-
1197
- vectorizationPatterns.add <ConvOpVectorization<ConvOp, N>>(context, mask);
1198
- }
1199
-
1200
- void mlir::linalg::populateConvVectorizationPatterns (
1201
- MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
1202
- ArrayRef<int64_t > tileSizes) {
1203
- RewritePatternSet tiling (context);
1204
- RewritePatternSet promotion (context);
1205
- RewritePatternSet vectorization (context);
1206
- populateVectorizationPatterns<Conv1DOp, 1 >(tiling, promotion, vectorization,
1207
- tileSizes);
1208
-
1209
- populateVectorizationPatterns<Conv2DOp, 2 >(tiling, promotion, vectorization,
1210
- tileSizes);
1211
-
1212
- populateVectorizationPatterns<Conv3DOp, 3 >(tiling, promotion, vectorization,
1213
- tileSizes);
1214
-
1215
- populateVectorizationPatterns<Conv1DNwcWcfOp, 3 >(tiling, promotion,
1216
- vectorization, tileSizes);
1217
-
1218
- populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4 >(tiling, promotion,
1219
- vectorization, tileSizes);
1220
-
1221
- populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5 >(
1222
- tiling, promotion, vectorization, tileSizes);
1223
-
1224
- patterns.push_back (std::move (tiling));
1225
- patterns.push_back (std::move (promotion));
1226
- patterns.push_back (std::move (vectorization));
1227
- }
1228
-
1229
1101
// ----------------------------------------------------------------------------//
1230
1102
// Forwarding patterns
1231
1103
// ----------------------------------------------------------------------------//
@@ -1754,40 +1626,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
1754
1626
};
1755
1627
} // namespace
1756
1628
1757
- /// Helper function to vectorize a `linalgOp` with convolution semantics.
1629
+ /// Helper function to vectorize a LinalgOp with convolution semantics.
1758
1630
// TODO: extend the generic vectorization to support windows and drop this.
1759
- static FailureOr<Operation *>
1760
- vectorizeConvolution (OpBuilder &b, ConvolutionOpInterface convOp) {
1761
- // TODO: these are legitimately part of ConvolutionOpInterface.
1762
- auto strides = convOp->getAttrOfType <DenseIntElementsAttr>(" strides" );
1763
- auto dilations = convOp->getAttrOfType <DenseIntElementsAttr>(" dilations" );
1631
+ static FailureOr<Operation *> vectorizeConvolution (OpBuilder &b, LinalgOp op) {
1632
+ // The ConvolutionOpInterface gives us guarantees of existence for
1633
+ // strides/dilations. However, we do not need to rely on those, we can simply
1634
+ // use them if present, otherwise use the default and let the generic conv.
1635
+ // matcher in the ConvGenerator succeed or fail.
1636
+ auto strides = op->getAttrOfType <DenseIntElementsAttr>(" strides" );
1637
+ auto dilations = op->getAttrOfType <DenseIntElementsAttr>(" dilations" );
1764
1638
auto stride = strides ? *strides.getValues <uint64_t >().begin () : 1 ;
1765
1639
auto dilation = dilations ? *dilations.getValues <uint64_t >().begin () : 1 ;
1766
- LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation ());
1767
- Conv1DNwcGenerator e (b, linalgOp, stride, dilation);
1640
+ Conv1DNwcGenerator e (b, op, stride, dilation);
1768
1641
auto res = e.generateConv ();
1769
1642
if (succeeded (res))
1770
1643
return res;
1771
1644
return e.generateDilatedConv ();
1772
1645
}
1773
1646
1774
- struct VectorizeConvolution
1775
- : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
1647
+ struct VectorizeConvolution : public OpInterfaceRewritePattern <LinalgOp> {
1776
1648
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
1777
1649
1778
- LogicalResult matchAndRewrite (ConvolutionOpInterface convOp ,
1650
+ LogicalResult matchAndRewrite (LinalgOp op ,
1779
1651
PatternRewriter &rewriter) const override {
1780
- FailureOr<Operation *> resultOrFail =
1781
- vectorizeConvolution (rewriter, convOp);
1652
+ FailureOr<Operation *> resultOrFail = vectorizeConvolution (rewriter, op);
1782
1653
if (failed (resultOrFail))
1783
1654
return failure ();
1784
1655
Operation *newOp = *resultOrFail;
1785
1656
if (newOp->getNumResults () == 0 ) {
1786
- rewriter.eraseOp (convOp .getOperation ());
1657
+ rewriter.eraseOp (op .getOperation ());
1787
1658
return success ();
1788
1659
}
1789
1660
assert (newOp->getNumResults () == 1 && " expected single result" );
1790
- rewriter.replaceOp (convOp .getOperation (), newOp->getResult (0 ));
1661
+ rewriter.replaceOp (op .getOperation (), newOp->getResult (0 ));
1791
1662
return success ();
1792
1663
}
1793
1664
};
0 commit comments