@@ -43,9 +43,8 @@ using namespace mlir::linalg;
43
43
// Debug-output helpers: DBGS() prefixes the stream with "[DEBUG_TYPE] ",
// LDBG(X) emits X only in LLVM_DEBUG builds (guarded by -debug-only=DEBUG_TYPE).
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)

/// Forward declaration: try to vectorize `convOp` as a convolution.
/// Defined later in this file alongside the Conv1DNwcGenerator helpers.
static FailureOr<Operation *>
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
49
48
50
49
// / Return the unique instance of OpType in `block` if it is indeed unique.
51
50
// / Return null if none or more than 1 instances exist.
@@ -637,12 +636,13 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
637
636
SmallVector<Value> results;
638
637
// TODO: isaConvolutionOpInterface that can also infer from generic
639
638
// features. Will require stride/dilation attributes inference.
640
- FailureOr<Operation *> convOr = vectorizeConvolution (rewriter, linalgOp);
641
- if (succeeded (convOr)) {
639
+ if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation ())) {
640
+ LDBG (" Vectorize as a conv: " << linalgOp);
641
+ FailureOr<Operation *> convOr = vectorizeConvolution (rewriter, convOp);
642
+ if (failed (convOr))
643
+ return failure ();
642
644
llvm::append_range (results, (*convOr)->getResults ());
643
645
} else {
644
- if (failed (vectorizeLinalgOpPrecondition (linalgOp)))
645
- return failure ();
646
646
LDBG (" Vectorize generic by broadcasting to a common shape: " << linalgOp);
647
647
if (failed (vectorizeAsLinalgGeneric (rewriter, linalgOp, results)))
648
648
return failure ();
@@ -1098,6 +1098,134 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
1098
1098
patterns.getContext (), baseBenefit.getBenefit () + 1 );
1099
1099
}
1100
1100
1101
// TODO: cleanup all the convolution vectorization patterns.
/// Vectorize a named ConvOp by collapsing it into a single
/// vector.contract over the dimensions selected by `mask`: the masked
/// dimensions of input and kernel are read as vectors and reduced; every
/// unmasked dimension must be of size 1 in both operands.
/// NOTE(review): result is written with memref.store at index 0...0 — this
/// presumably assumes the output has unit size in all dims; confirm against
/// callers that tile to size-1 outputs.
template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();

  OpOperand *input = op.getInputOperand(0);
  OpOperand *kernel = op.getInputOperand(1);
  OpOperand *output = op.getOutputOperand(0);
  ArrayRef<int64_t> inShape = op.getShape(input);
  ArrayRef<int64_t> kShape = op.getShape(kernel);

  // Only static shapes: the VectorType built below needs constant dims.
  if (llvm::any_of(inShape, ShapedType::isDynamic) ||
      llvm::any_of(kShape, ShapedType::isDynamic))
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of not vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    // Masked (vectorized) dims must agree between input and kernel so the
    // element-wise contraction below is well-formed.
    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  int64_t rank = op.getRank(input);
  int64_t numDims = mapping.size();
  Type elemType = getElementTypeOrSelf(input->get());

  // Read the masked dims of input and kernel as `vecType` vectors starting
  // at the all-zeros index.
  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank,
                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = rewriter.create<vector::TransferReadOp>(
      loc, vecType, input->get(), zeros, map);
  auto kernelVec = rewriter.create<vector::TransferReadOp>(
      loc, vecType, kernel->get(), zeros, map);

  // Scalar zero accumulator: the contraction fully reduces to one element.
  auto acc = rewriter.create<arith::ConstantOp>(loc, elemType,
                                                rewriter.getZeroAttr(elemType));

  // lhs/rhs use the identity map; the result map is empty (rank-0 result)
  // since every dimension is a reduction dimension.
  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  // Store the scalar result into the output buffer and drop the conv op.
  rewriter.create<memref::StoreOp>(loc, result, output->get(),
                                   ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}
1168
+
1169
/// Inserts tiling, promotion and vectorization pattern for ConvOp
/// conversion into corresponding pattern lists.
///
/// The three stages are chained through transformation-filter markers:
/// tiling tags ops "TILED", promotion requires "TILED" and tags "PROMOTED",
/// and the final ConvOpVectorization pattern rewrites the tiled op.
/// `N` is the number of trailing tile sizes that map onto the conv's
/// dimensions; dims whose tile size is > 1 are the ones vectorized.
template <typename ConvOp, unsigned N>
static void populateVectorizationPatterns(
    RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
    RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
  auto *context = tilingPatterns.getContext();
  // Not enough tile sizes supplied for this conv's rank: nothing to do.
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.add<LinalgTilingPattern>(
      ConvOp::getOperationName(), context,
      LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                 StringAttr::get(context, kTiledMarker)));

  promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(StringAttr::get(context, kTiledMarker),
                                 StringAttr::get(context, kPromotedMarker)));

  // Build the vectorization mask from the last N tile sizes: a dimension is
  // vectorized iff its tile size is greater than 1.
  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
}
1199
+
1200
/// Populates `patterns` with three staged pattern sets (tiling, promotion,
/// vectorization — in that application order) covering the supported named
/// convolution ops. The integer template argument of each instantiation is
/// the number of trailing entries of `tileSizes` used for that op.
void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
    ArrayRef<int64_t> tileSizes) {
  RewritePatternSet tiling(context);
  RewritePatternSet promotion(context);
  RewritePatternSet vectorization(context);
  populateVectorizationPatterns<Conv1DOp, 1>(tiling, promotion, vectorization,
                                             tileSizes);

  populateVectorizationPatterns<Conv2DOp, 2>(tiling, promotion, vectorization,
                                             tileSizes);

  populateVectorizationPatterns<Conv3DOp, 3>(tiling, promotion, vectorization,
                                             tileSizes);

  populateVectorizationPatterns<Conv1DNwcWcfOp, 3>(tiling, promotion,
                                                   vectorization, tileSizes);

  populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4>(tiling, promotion,
                                                     vectorization, tileSizes);

  populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  // Stage order matters: callers apply these sets in sequence.
  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}
1228
+
1101
1229
// ----------------------------------------------------------------------------//
1102
1230
// Forwarding patterns
1103
1231
// ----------------------------------------------------------------------------//
@@ -1640,39 +1768,40 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
1640
1768
};
1641
1769
} // namespace
1642
1770
1643
/// Helper function to vectorize a `linalgOp` with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *>
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
  // TODO: these are legitimately part of ConvolutionOpInterface.
  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  // Only the first stride/dilation value is consulted; absent attributes
  // default to 1 and the ConvGenerator matcher decides success or failure.
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
  Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
  // Try the plain conv lowering first, then fall back to the dilated form.
  auto res = e.generateConv();
  if (succeeded(res))
    return res;
  return e.generateDilatedConv();
}
1660
1787
1661
/// Rewrite pattern wrapping vectorizeConvolution(): replaces any op
/// implementing ConvolutionOpInterface with its vectorized equivalent.
struct VectorizeConvolution
    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail =
        vectorizeConvolution(rewriter, convOp);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    // Buffer (memref) case: the vectorized op produced no SSA results, so
    // the original op is simply erased.
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(convOp.getOperation());
      return success();
    }
    // Tensor case: forward the single vectorized result to all users.
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
    return success();
  }
};
0 commit comments