Skip to content

Commit 7b615a8

Browse files
committed
[mlir][linalg] Rewrite linalg.conv_2d_nhwc_hwcf into 1-D
We'd like to take a progressive approach towards convolution op CodeGen, by 1) tiling it to fit compute hierarchy first, and then 2) tiling along window dimensions with size 1 to reduce the problem to be matmul-like. After that, we can 3) downscale high-D convolution ops to low-D by removing the size-1 window dimensions. The final step would be 4) vectorizing the low-D convolution op directly. We have patterns for 1), 2), and 4). This commit adds a pattern for 3) for `linalg.conv_2d_nhwc_hwcf` ops as a starter. Supporting other high-D convolution ops should be similar and mechanical. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D112928
1 parent 67887b0 commit 7b615a8

File tree

4 files changed

+184
-1
lines changed

4 files changed

+184
-1
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,15 @@ void populateConvVectorizationPatterns(
4646
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
4747
ArrayRef<int64_t> tileSizes);
4848

49-
/// Populates patterns for vectorizing convolution ops.
49+
/// Populates patterns to decompose high-D convolution ops into low-D ones. This
50+
/// is a step in progressive lowering for convolution ops, afterwards we can
51+
/// vectorize the low-D convolution ops.
52+
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
53+
PatternBenefit benefit = 1);
54+
55+
/// Populates patterns for vectorizing low-D convolution ops. This is a step in
56+
/// progressive lowering for convolution ops, it assumes high-D convolution ops
57+
/// were decomposed previously.
5058
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
5159
PatternBenefit benefit = 1);
5260

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,3 +840,98 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
840840
rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
841841
return success();
842842
}
843+
844+
namespace {
845+
// The following are patterns for downscaling convolution ops with size-1
846+
// window dimensions.
847+
//
848+
// Note that we'd eventually want to write such transformations in a generic
849+
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
850+
// and then turning back to named ops. But for now it's fine to have a few
851+
// patterns matching special ops to get started.
852+
853+
/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
854+
/// convolution ops.
855+
struct DownscaleSizeOneWindowed2DConvolution final
856+
: public OpRewritePattern<Conv2DNhwcHwcfOp> {
857+
using OpRewritePattern::OpRewritePattern;
858+
859+
LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
860+
PatternRewriter &rewriter) const override {
861+
auto linalgOp = cast<linalg::LinalgOp>(*convOp);
862+
if (linalgOp.hasBufferSemantics())
863+
return failure(); // To be implemented
864+
865+
Value input = convOp.inputs().front();
866+
Value filter = convOp.inputs().back();
867+
Value output = convOp.outputs().front();
868+
869+
auto inputType = input.getType().dyn_cast<RankedTensorType>();
870+
auto filterType = filter.getType().dyn_cast<RankedTensorType>();
871+
auto outputType = output.getType().dyn_cast<RankedTensorType>();
872+
873+
auto inputShape = inputType.getShape();
874+
auto filterShape = filterType.getShape();
875+
auto outputShape = outputType.getShape();
876+
877+
// Only handle the case where at least one of the window dimensions is
878+
// of size 1. Other cases can rely on tiling to reduce to such cases.
879+
int64_t fhSize = filterShape[0], fwSize = filterShape[1];
880+
int64_t ohSize = outputShape[1], owSize = outputShape[2];
881+
if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
882+
return failure();
883+
bool removeH = ohSize == 1;
884+
885+
// Get new shapes and types for all operands by removing the size-1
886+
// dimension.
887+
888+
SmallVector<int64_t, 3> newInputShape{
889+
inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
890+
auto newInputType = RankedTensorType::get(
891+
newInputShape, inputType.getElementType(), inputType.getEncoding());
892+
893+
SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
894+
filterShape[2], filterShape[3]};
895+
auto newFilterType = RankedTensorType::get(
896+
newFilterShape, filterType.getElementType(), filterType.getEncoding());
897+
898+
SmallVector<int64_t, 3> newOutputShape{
899+
outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
900+
auto newOutputType = RankedTensorType::get(
901+
newOutputShape, outputType.getElementType(), outputType.getEncoding());
902+
903+
SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
904+
SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
905+
906+
// Reshape all operands for 1-D convolution.
907+
Location loc = convOp.getLoc();
908+
Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
909+
loc, newInputType, input, ioReshapeIndices);
910+
Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
911+
loc, newFilterType, filter, fReshapeIndices);
912+
Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
913+
loc, newOutputType, output, ioReshapeIndices);
914+
915+
// We need to shrink the strides and dilations too.
916+
auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
917+
auto stridesAttr = rewriter.getI64VectorAttr(stride);
918+
auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
919+
auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
920+
921+
auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
922+
loc, newOutputType, ValueRange{newInput, newFilter},
923+
ValueRange{newOutput}, stridesAttr, dilationsAttr);
924+
925+
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
926+
convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
927+
return success();
928+
};
929+
};
930+
931+
} // namespace
932+
933+
/// Registers the high-D-to-low-D convolution decomposition patterns into
/// `patterns` with the given `benefit`.
void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<DownscaleSizeOneWindowed2DConvolution>(context, benefit);
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-decompose-convolution-patterns %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
4+
// CHECK-SAME: (%[[INPUT:.+]]: tensor<4x1x6x3xf32>, %[[FILTER:.+]]: tensor<1x2x3x8xf32>, %[[INIT:.+]]: tensor<4x1x2x8xf32>)
5+
func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
6+
%0 = linalg.conv_2d_nhwc_hwcf
7+
{dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
8+
ins(%input, %filter : tensor<4x1x6x3xf32>, tensor<1x2x3x8xf32>)
9+
outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
10+
return %0 : tensor<4x1x2x8xf32>
11+
}
12+
13+
// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
14+
// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
15+
// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
16+
// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
17+
// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
18+
// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
19+
// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
20+
// CHECK-SAME: dilations = dense<3> : vector<1xi64>
21+
// CHECK-SAME: strides = dense<2> : vector<1xi64>
22+
// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
23+
// CHECK-SAME: outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
24+
// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
25+
// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
26+
// CHECK: return %[[CONV_2D]]
27+
28+
// -----
29+
30+
// CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
31+
// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x1x?xf32>, %[[FILTER:.+]]: tensor<?x1x?x?xf32>, %[[INIT:.+]]: tensor<?x?x1x?xf32>)
32+
func @conv2d_nhwc_qxqx1xq_tensor(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x1x?x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
33+
%0 = linalg.conv_2d_nhwc_hwcf
34+
{dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
35+
ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x1x?x?xf32>)
36+
outs(%init : tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
37+
return %0 : tensor<?x?x1x?xf32>
38+
}
39+
40+
// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
41+
// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
42+
// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
43+
// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
44+
// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
45+
// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
46+
// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
47+
// CHECK-SAME: dilations = dense<2> : vector<1xi64>
48+
// CHECK-SAME: strides = dense<3> : vector<1xi64>
49+
// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
50+
// CHECK-SAME: outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
51+
// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
52+
// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
53+
// CHECK: return %[[CONV_2D]]
54+
55+
// -----
56+
57+
// Do not convert convolution ops whose window dimensions are not ones.
58+
59+
// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
60+
func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
61+
// CHECK: linalg.conv_2d_nhwc_hwcf
62+
%0 = linalg.conv_2d_nhwc_hwcf
63+
{dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
64+
ins(%input, %filter : tensor<4x3x5x3xf32>, tensor<2x2x3x8xf32>)
65+
outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
66+
return %0 : tensor<4x1x2x8xf32>
67+
}

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ struct TestLinalgTransforms
152152
llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
153153
"tiled_loop"),
154154
llvm::cl::init("for")};
155+
Option<bool> testDecomposeConvolutionPattern{
156+
*this, "test-decompose-convolution-patterns",
157+
llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
158+
"into low-D ones"),
159+
llvm::cl::init(false)};
155160
};
156161
} // end anonymous namespace
157162

@@ -576,6 +581,12 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
576581
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
577582
}
578583

584+
/// Applies the convolution decomposition patterns greedily to `funcOp`.
static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
  RewritePatternSet decompositionPatterns(funcOp.getContext());
  populateDecomposeConvolutionPatterns(decompositionPatterns);
  (void)applyPatternsAndFoldGreedily(funcOp,
                                     std::move(decompositionPatterns));
}
589+
579590
static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
580591
RewritePatternSet patterns(funcOp.getContext());
581592
patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
@@ -819,6 +830,8 @@ void TestLinalgTransforms::runOnFunction() {
819830
return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
820831
if (testInterchangePattern.hasValue())
821832
return applyInterchangePattern(getFunction(), testInterchangePattern);
833+
if (testDecomposeConvolutionPattern)
834+
return applyDecomposeConvolutionPatterns(getFunction());
822835
}
823836

824837
namespace mlir {

0 commit comments

Comments
 (0)