Commit f40a579

nicolasvasilache authored and joker-eph committed
Revert "[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface"
This reverts commit c8f5735. The integration tests are broken.
1 parent caf5548 · commit f40a579

File tree

6 files changed (+354, -21 lines changed)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Lines changed: 5 additions & 0 deletions

@@ -49,6 +49,11 @@ using LinalgLoops = SmallVector<Operation *, 4>;
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 
+/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
+void populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
+    ArrayRef<int64_t> tileSizes);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
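
Note: the re-added populateConvVectorizationPatterns fills one RewritePatternSet per stage (tiling, promotion, vectorization) rather than a single set. A minimal consumption sketch, assuming a hypothetical driver function and a FuncOp anchor (neither is part of this diff):

    // Hypothetical driver: run each staged pattern set to convergence, in
    // order, on a function. The name `runConvVectorization` is illustrative.
    static void runConvVectorization(FuncOp func, ArrayRef<int64_t> tileSizes) {
      SmallVector<RewritePatternSet> stages;
      linalg::populateConvVectorizationPatterns(func.getContext(), stages,
                                                tileSizes);
      for (RewritePatternSet &stage : stages)
        (void)applyPatternsAndFoldGreedily(func, std::move(stage));
    }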

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Lines changed: 150 additions & 21 deletions

@@ -43,9 +43,8 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-/// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.

@@ -637,12 +636,13 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. Will require stride/dilation attributes inference.
-  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
-  if (succeeded(convOr)) {
+  if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    LDBG("Vectorize as a conv: " << linalgOp);
+    FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
+    if (failed(convOr))
+      return failure();
     llvm::append_range(results, (*convOr)->getResults());
   } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
-      return failure();
     LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
     if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
       return failure();
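
Note: this hunk changes behavior, not just types. A condensed sketch of the resulting dispatch (result handling elided; a paraphrase of the code above, not new API):

    // Ops implementing ConvolutionOpInterface either vectorize as a
    // convolution or make vectorize() fail outright; there is no fallback
    // to the generic path for them.
    if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
      if (failed(vectorizeConvolution(rewriter, convOp)))
        return failure();
    } else if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results))) {
      return failure();
    }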

@@ -1098,6 +1098,134 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
       patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
+// TODO: cleanup all the convolution vectorization patterns.
+template <class ConvOp, int N>
+LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
+    ConvOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MLIRContext *context = op.getContext();
+
+  OpOperand *input = op.getInputOperand(0);
+  OpOperand *kernel = op.getInputOperand(1);
+  OpOperand *output = op.getOutputOperand(0);
+  ArrayRef<int64_t> inShape = op.getShape(input);
+  ArrayRef<int64_t> kShape = op.getShape(kernel);
+
+  if (llvm::any_of(inShape, ShapedType::isDynamic) ||
+      llvm::any_of(kShape, ShapedType::isDynamic))
+    return failure();
+
+  SmallVector<AffineExpr, 4> mapping;
+  SmallVector<int64_t, 4> vectorDims;
+  // Fail to apply when the size of not vectorized dimension is not 1.
+  for (unsigned i = 0; i < N; i++) {
+    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
+      return failure();
+
+    if (mask[i] && inShape[i] != kShape[i])
+      return failure();
+
+    if (mask[i]) {
+      mapping.push_back(getAffineDimExpr(i, context));
+      vectorDims.push_back(inShape[i]);
+    }
+  }
+
+  int64_t rank = op.getRank(input);
+  int64_t numDims = mapping.size();
+  Type elemType = getElementTypeOrSelf(input->get());
+
+  auto map = AffineMap::get(rank, 0, mapping, context);
+  SmallVector<Value, 4> zeros(rank,
+                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  auto vecType = VectorType::get(vectorDims, elemType);
+
+  auto inputVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, input->get(), zeros, map);
+  auto kernelVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, kernel->get(), zeros, map);
+
+  auto acc = rewriter.create<arith::ConstantOp>(loc, elemType,
+                                                rewriter.getZeroAttr(elemType));
+
+  std::array<AffineMap, 3> indexingMaps{
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::get(numDims, 0, {}, context)};
+
+  std::vector<StringRef> iteratorTypes(numDims, "reduction");
+
+  auto result = rewriter.create<vector::ContractionOp>(
+      loc, inputVec, kernelVec, acc,
+      rewriter.getAffineMapArrayAttr(indexingMaps),
+      rewriter.getStrArrayAttr(iteratorTypes));
+
+  rewriter.create<memref::StoreOp>(loc, result, output->get(),
+                                   ValueRange(zeros));
+  rewriter.eraseOp(op);
+  return success();
+}
+
+/// Inserts tiling, promotion and vectorization pattern for ConvOp
+/// conversion into corresponding pattern lists.
+template <typename ConvOp, unsigned N>
+static void populateVectorizationPatterns(
+    RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
+    RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
+  auto *context = tilingPatterns.getContext();
+  if (tileSizes.size() < N)
+    return;
+
+  constexpr static StringRef kTiledMarker = "TILED";
+  constexpr static StringRef kPromotedMarker = "PROMOTED";
+  tilingPatterns.add<LinalgTilingPattern>(
+      ConvOp::getOperationName(), context,
+      LinalgTilingOptions().setTileSizes(tileSizes),
+      LinalgTransformationFilter(ArrayRef<StringAttr>{},
+                                 StringAttr::get(context, kTiledMarker)));
+
+  promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
+      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+      LinalgTransformationFilter(StringAttr::get(context, kTiledMarker),
+                                 StringAttr::get(context, kPromotedMarker)));
+
+  SmallVector<bool, 4> mask(N);
+  int offset = tileSizes.size() - N;
+  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
+                 [](int64_t i) -> bool { return i > 1; });
+
+  vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
+}
+
+void mlir::linalg::populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
+    ArrayRef<int64_t> tileSizes) {
+  RewritePatternSet tiling(context);
+  RewritePatternSet promotion(context);
+  RewritePatternSet vectorization(context);
+  populateVectorizationPatterns<Conv1DOp, 1>(tiling, promotion, vectorization,
+                                             tileSizes);
+
+  populateVectorizationPatterns<Conv2DOp, 2>(tiling, promotion, vectorization,
+                                             tileSizes);
+
+  populateVectorizationPatterns<Conv3DOp, 3>(tiling, promotion, vectorization,
+                                             tileSizes);
+
+  populateVectorizationPatterns<Conv1DNwcWcfOp, 3>(tiling, promotion,
+                                                   vectorization, tileSizes);
+
+  populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4>(tiling, promotion,
+                                                     vectorization, tileSizes);
+
+  populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5>(
+      tiling, promotion, vectorization, tileSizes);
+
+  patterns.push_back(std::move(tiling));
+  patterns.push_back(std::move(promotion));
+  patterns.push_back(std::move(vectorization));
+}
+
 //----------------------------------------------------------------------------//
 // Forwarding patterns
 //----------------------------------------------------------------------------//
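
For illustration, a worked example of the mask computation above (a sketch, not part of the diff), using the tile-sizes=1,3 configuration from the test below with ConvOp = Conv1DOp and N = 1:

    // offset = tileSizes.size() - N = 2 - 1 = 1, so only the trailing tile
    // size (3) is inspected; 3 > 1 marks that dimension for vectorization.
    int64_t tileSizes[] = {1, 3};
    SmallVector<bool, 4> mask(1);
    std::transform(std::begin(tileSizes) + 1, std::end(tileSizes), mask.begin(),
                   [](int64_t i) -> bool { return i > 1; });
    // mask == {true}: ConvOpVectorization<Conv1DOp, 1> vectorizes the single
    // spatial dimension; the vector size comes from the (tiled) shape, here 3.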

@@ -1640,39 +1768,40 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
 };
 } // namespace
 
-/// Helper function to vectorize a LinalgOp with convolution semantics.
+/// Helper function to vectorize a `linalgOp` with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
-  // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
-  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
+  // TODO: these are legitimately part of ConvolutionOpInterface.
+  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
+  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DNwcGenerator e(b, op, stride, dilation);
+  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
+  Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
   auto res = e.generateConv();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();
 }
 
-struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
+struct VectorizeConvolution
+    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(LinalgOp op,
+  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
+    FailureOr<Operation *> resultOrFail =
+        vectorizeConvolution(rewriter, convOp);
     if (failed(resultOrFail))
       return failure();
     Operation *newOp = *resultOrFail;
     if (newOp->getNumResults() == 0) {
-      rewriter.eraseOp(op.getOperation());
+      rewriter.eraseOp(convOp.getOperation());
       return success();
     }
     assert(newOp->getNumResults() == 1 && "expected single result");
-    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
+    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
     return success();
   }
 };
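
For context: VectorizeConvolution is registered by the low-D convolution populate function declared in Transforms.h; after this change it fires only on ops implementing ConvolutionOpInterface rather than on every LinalgOp. A sketch of that registration (the exact signature is assumed from the surrounding file, which is not shown in this hunk):

    void mlir::linalg::populateConvolutionVectorizationPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
      patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
    }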

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file
+// | FileCheck %s
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)>
+// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+// CHECK-LABEL: @conv_1d
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[v0:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?xf32>
+// CHECK: %[[v1:.*]] = memref.dim %[[arg2]], %[[c0]] : memref<?xf32>
+// CHECK: %[[v2:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32>
+// CHECK: %[[v3:.*]] = memref.alloc(%[[c12]]) : memref<?xi8>
+// CHECK: %[[v4:.*]] = memref.alloc(%[[c12]]) : memref<?xi8>
+// CHECK: %[[v5:.*]] = memref.alloc(%[[c4]]) : memref<?xi8>
+// CHECK: %[[v6:.*]] = memref.view %[[v3]][%[[c0]]][] : memref<?xi8> to memref<3xf32>
+// CHECK: %[[v7:.*]] = memref.view %[[v4]][%[[c0]]][] : memref<?xi8> to memref<3xf32>
+// CHECK: %[[v8:.*]] = memref.view %[[v5]][%[[c0]]][] : memref<?xi8> to memref<1xf32>
+// CHECK: scf.for %[[arg3:.*]] = %[[c0]] to %[[v1]] step %[[c1]] {
+// CHECK:   %[[v9:.*]] = affine.min #[[$map0]](%[[arg3]])[%[[v1]]]
+// CHECK:   %[[v10:.*]] = subview %[[arg2]][%[[arg3]]] [%[[v9]]] [1] : memref<?xf32> to memref<?xf32, #[[$map1]]>
+// CHECK:   %[[v11:.*]] = subview %[[v8]][0] [%[[v9]]] [1] : memref<1xf32> to memref<?xf32>
+// CHECK:   scf.for %[[arg4:.*]] = %[[c0]] to %[[v0]] step %[[c3]] {
+// CHECK:     %[[v12:.*]] = affine.apply #[[$map2]](%[[arg3]], %[[arg4]])
+// CHECK:     %[[v13:.*]] = affine.min #[[$map3]](%[[arg3]], %[[arg4]])[%[[v2]]]
+// CHECK:     %[[v14:.*]] = subview %arg0[%12] [%13] [1] : memref<?xf32> to memref<?xf32, #[[$map1]]>
+// CHECK:     %[[v15:.*]] = affine.min #[[$map4]](%arg4)[%0]
+// CHECK:     %[[v16:.*]] = subview %[[arg1]][%[[arg4]]] [%[[v15]]] [1] : memref<?xf32> to memref<?xf32, #[[$map1]]>
+// CHECK:     %[[v17:.*]] = subview %[[v6]][0] [%[[v13]]] [1] : memref<3xf32> to memref<?xf32>
+// CHECK:     %[[v19:.*]] = vector.transfer_read %[[v6]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32>
+// CHECK:     %[[v20:.*]] = vector.transfer_read %[[v7]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32>
+// CHECK:     %[[v21:.*]] = arith.mulf %[[v19]], %[[v20]] : vector<3xf32>
+// CHECK:     %[[v22:.*]] = vector.reduction "add", %[[v21]], %[[cst]] : vector<3xf32> into f32
+// CHECK:     store %[[v22]], %[[v8]][%[[c0]]] : memref<1xf32>
+// CHECK:   scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] {
+// CHECK:     %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref<?xf32>
+// CHECK:     store %[[v23]], %[[v10]][%[[arg5]]] : memref<?xf32, #[[$map1]]>
+  linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
+    outs(%arg2 : memref<?xf32>)
+  return
+}
+
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
   TestComprehensiveBufferize.cpp
+  TestConvVectorization.cpp
   TestLinalgCodegenStrategy.cpp
   TestLinalgDistribution.cpp
   TestLinalgElementwiseFusion.cpp