Skip to content

Commit c8f5735

Browse files
[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Differential Revision: https://reviews.llvm.org/D117323
1 parent 2164c54 commit c8f5735

File tree

6 files changed

+21
-354
lines changed

6 files changed

+21
-354
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,6 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
4646
//===----------------------------------------------------------------------===//
4747
using LinalgLoops = SmallVector<Operation *, 4>;
4848

49-
/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
50-
void populateConvVectorizationPatterns(
51-
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
52-
ArrayRef<int64_t> tileSizes);
53-
5449
/// Populate patterns for vectorizing low-D convolution ops. This is a step in
5550
/// progressive lowering for convolution ops, it assume high-D convolution ops
5651
/// were decomposed previously.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 21 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ using namespace mlir::linalg;
4343
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
4444
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
4545

46-
static FailureOr<Operation *>
47-
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
46+
/// Try to vectorize `convOp` as a convolution.
47+
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
48+
LinalgOp convOp);
4849

4950
/// Return the unique instance of OpType in `block` if it is indeed unique.
5051
/// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
636637
SmallVector<Value> results;
637638
// TODO: isaConvolutionOpInterface that can also infer from generic
638639
// features. Will require stride/dilation attributes inference.
639-
if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
640-
LDBG("Vectorize as a conv: " << linalgOp);
641-
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
642-
if (failed(convOr))
643-
return failure();
640+
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
641+
if (succeeded(convOr)) {
644642
llvm::append_range(results, (*convOr)->getResults());
645643
} else {
644+
if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
645+
return failure();
646646
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
647647
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
648648
return failure();
@@ -1098,134 +1098,6 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
10981098
patterns.getContext(), baseBenefit.getBenefit() + 1);
10991099
}
11001100

1101-
// TODO: cleanup all the convolution vectorization patterns.
1102-
template <class ConvOp, int N>
1103-
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
1104-
ConvOp op, PatternRewriter &rewriter) const {
1105-
Location loc = op.getLoc();
1106-
MLIRContext *context = op.getContext();
1107-
1108-
OpOperand *input = op.getInputOperand(0);
1109-
OpOperand *kernel = op.getInputOperand(1);
1110-
OpOperand *output = op.getOutputOperand(0);
1111-
ArrayRef<int64_t> inShape = op.getShape(input);
1112-
ArrayRef<int64_t> kShape = op.getShape(kernel);
1113-
1114-
if (llvm::any_of(inShape, ShapedType::isDynamic) ||
1115-
llvm::any_of(kShape, ShapedType::isDynamic))
1116-
return failure();
1117-
1118-
SmallVector<AffineExpr, 4> mapping;
1119-
SmallVector<int64_t, 4> vectorDims;
1120-
// Fail to apply when the size of not vectorized dimension is not 1.
1121-
for (unsigned i = 0; i < N; i++) {
1122-
if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
1123-
return failure();
1124-
1125-
if (mask[i] && inShape[i] != kShape[i])
1126-
return failure();
1127-
1128-
if (mask[i]) {
1129-
mapping.push_back(getAffineDimExpr(i, context));
1130-
vectorDims.push_back(inShape[i]);
1131-
}
1132-
}
1133-
1134-
int64_t rank = op.getRank(input);
1135-
int64_t numDims = mapping.size();
1136-
Type elemType = getElementTypeOrSelf(input->get());
1137-
1138-
auto map = AffineMap::get(rank, 0, mapping, context);
1139-
SmallVector<Value, 4> zeros(rank,
1140-
rewriter.create<arith::ConstantIndexOp>(loc, 0));
1141-
auto vecType = VectorType::get(vectorDims, elemType);
1142-
1143-
auto inputVec = rewriter.create<vector::TransferReadOp>(
1144-
loc, vecType, input->get(), zeros, map);
1145-
auto kernelVec = rewriter.create<vector::TransferReadOp>(
1146-
loc, vecType, kernel->get(), zeros, map);
1147-
1148-
auto acc = rewriter.create<arith::ConstantOp>(loc, elemType,
1149-
rewriter.getZeroAttr(elemType));
1150-
1151-
std::array<AffineMap, 3> indexingMaps{
1152-
AffineMap::getMultiDimIdentityMap(numDims, context),
1153-
AffineMap::getMultiDimIdentityMap(numDims, context),
1154-
AffineMap::get(numDims, 0, {}, context)};
1155-
1156-
std::vector<StringRef> iteratorTypes(numDims, "reduction");
1157-
1158-
auto result = rewriter.create<vector::ContractionOp>(
1159-
loc, inputVec, kernelVec, acc,
1160-
rewriter.getAffineMapArrayAttr(indexingMaps),
1161-
rewriter.getStrArrayAttr(iteratorTypes));
1162-
1163-
rewriter.create<memref::StoreOp>(loc, result, output->get(),
1164-
ValueRange(zeros));
1165-
rewriter.eraseOp(op);
1166-
return success();
1167-
}
1168-
1169-
/// Inserts tiling, promotion and vectorization pattern for ConvOp
1170-
/// conversion into corresponding pattern lists.
1171-
template <typename ConvOp, unsigned N>
1172-
static void populateVectorizationPatterns(
1173-
RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
1174-
RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
1175-
auto *context = tilingPatterns.getContext();
1176-
if (tileSizes.size() < N)
1177-
return;
1178-
1179-
constexpr static StringRef kTiledMarker = "TILED";
1180-
constexpr static StringRef kPromotedMarker = "PROMOTED";
1181-
tilingPatterns.add<LinalgTilingPattern>(
1182-
ConvOp::getOperationName(), context,
1183-
LinalgTilingOptions().setTileSizes(tileSizes),
1184-
LinalgTransformationFilter(ArrayRef<StringAttr>{},
1185-
StringAttr::get(context, kTiledMarker)));
1186-
1187-
promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
1188-
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
1189-
LinalgTransformationFilter(StringAttr::get(context, kTiledMarker),
1190-
StringAttr::get(context, kPromotedMarker)));
1191-
1192-
SmallVector<bool, 4> mask(N);
1193-
int offset = tileSizes.size() - N;
1194-
std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
1195-
[](int64_t i) -> bool { return i > 1; });
1196-
1197-
vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
1198-
}
1199-
1200-
void mlir::linalg::populateConvVectorizationPatterns(
1201-
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
1202-
ArrayRef<int64_t> tileSizes) {
1203-
RewritePatternSet tiling(context);
1204-
RewritePatternSet promotion(context);
1205-
RewritePatternSet vectorization(context);
1206-
populateVectorizationPatterns<Conv1DOp, 1>(tiling, promotion, vectorization,
1207-
tileSizes);
1208-
1209-
populateVectorizationPatterns<Conv2DOp, 2>(tiling, promotion, vectorization,
1210-
tileSizes);
1211-
1212-
populateVectorizationPatterns<Conv3DOp, 3>(tiling, promotion, vectorization,
1213-
tileSizes);
1214-
1215-
populateVectorizationPatterns<Conv1DNwcWcfOp, 3>(tiling, promotion,
1216-
vectorization, tileSizes);
1217-
1218-
populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4>(tiling, promotion,
1219-
vectorization, tileSizes);
1220-
1221-
populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5>(
1222-
tiling, promotion, vectorization, tileSizes);
1223-
1224-
patterns.push_back(std::move(tiling));
1225-
patterns.push_back(std::move(promotion));
1226-
patterns.push_back(std::move(vectorization));
1227-
}
1228-
12291101
//----------------------------------------------------------------------------//
12301102
// Forwarding patterns
12311103
//----------------------------------------------------------------------------//
@@ -1754,40 +1626,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
17541626
};
17551627
} // namespace
17561628

1757-
/// Helper function to vectorize a `linalgOp` with convolution semantics.
1629+
/// Helper function to vectorize a LinalgOp with convolution semantics.
17581630
// TODO: extend the generic vectorization to support windows and drop this.
1759-
static FailureOr<Operation *>
1760-
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
1761-
// TODO: these are legitimately part of ConvolutionOpInterface.
1762-
auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
1763-
auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
1631+
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
1632+
// The ConvolutionOpInterface gives us guarantees of existence for
1633+
// strides/dilations. However, we do not need to rely on those, we can simply
1634+
// use them if present, otherwise use the default and let the generic conv.
1635+
// matcher in the ConvGenerator succeed or fail.
1636+
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
1637+
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
17641638
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
17651639
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
1766-
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
1767-
Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
1640+
Conv1DNwcGenerator e(b, op, stride, dilation);
17681641
auto res = e.generateConv();
17691642
if (succeeded(res))
17701643
return res;
17711644
return e.generateDilatedConv();
17721645
}
17731646

1774-
struct VectorizeConvolution
1775-
: public OpInterfaceRewritePattern<ConvolutionOpInterface> {
1647+
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
17761648
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
17771649

1778-
LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
1650+
LogicalResult matchAndRewrite(LinalgOp op,
17791651
PatternRewriter &rewriter) const override {
1780-
FailureOr<Operation *> resultOrFail =
1781-
vectorizeConvolution(rewriter, convOp);
1652+
FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
17821653
if (failed(resultOrFail))
17831654
return failure();
17841655
Operation *newOp = *resultOrFail;
17851656
if (newOp->getNumResults() == 0) {
1786-
rewriter.eraseOp(convOp.getOperation());
1657+
rewriter.eraseOp(op.getOperation());
17871658
return success();
17881659
}
17891660
assert(newOp->getNumResults() == 1 && "expected single result");
1790-
rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
1661+
rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
17911662
return success();
17921663
}
17931664
};

mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir

Lines changed: 0 additions & 53 deletions
This file was deleted.

mlir/test/lib/Dialect/Linalg/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRLinalgTestPasses
33
TestComprehensiveBufferize.cpp
4-
TestConvVectorization.cpp
54
TestLinalgCodegenStrategy.cpp
65
TestLinalgDistribution.cpp
76
TestLinalgElementwiseFusion.cpp

0 commit comments

Comments (0)