Commit b828506
[mlir][Linalg] Add a DownscaleDepthwiseConv2DNhwcHwcOp decomposition pattern.
Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D113907
Parent: b4e50e5

2 files changed: +99 −2 lines

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 75 additions & 2 deletions
@@ -904,10 +904,83 @@ struct DownscaleSizeOneWindowed2DConvolution final
   };
 };
 
+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+    if (linalgOp.hasBufferSemantics())
+      return failure(); // To be implemented.
+
+    Value input = convOp.inputs().front();
+    Value kernel = convOp.inputs().back();
+    Value output = convOp.outputs().front();
+
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+    auto kernelShape = kernelType.getShape();
+    auto outputShape = outputType.getShape();
+
+    // Only handle the case where at least one of the window dimensions is
+    // of size 1. Other cases can rely on tiling to reduce to such cases.
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+    int64_t ohSize = outputShape[1], owSize = outputShape[2];
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
+    if (!removeH && !removeW)
+      return failure();
+
+    // Get new shapes and types for all operands by removing the size-1
+    // dimension.
+    using RTTBuilder = RankedTensorType::Builder;
+    auto newInputType = RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
+    auto newKernelType = RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
+    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+    // Rank-reduce operands.
+    Location loc = convOp.getLoc();
+    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, input, newInputType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
+    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, output, newOutputType);
+
+    // Rank-reduce strides and dilations too.
+    // TODO: dropDim 1-liner helper.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    strides.erase(strides.begin() + (removeH ? 0 : 1));
+    auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+    auto dilations =
+        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+        loc, newOutputType, ValueRange{newInput, newKernel},
+        ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+    // Insert the 1-D result back into the original 2-D output tensor.
+    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, conv1DOp.getResult(0), output);
+    rewriter.replaceOp(convOp, inserted);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
-                                                      benefit);
+  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+                                                  benefit);
 }
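For context, here is a minimal sketch of how a client might run these decomposition patterns over an op with the greedy rewrite driver. The wrapper function is hypothetical; `linalg::populateDecomposeConvolutionPatterns` (with `benefit` defaulting to 1) and `applyPatternsAndFoldGreedily` are the actual entry points this commit plugs into.

```cpp
// Hypothetical driver sketch; the population and driver calls are standard
// MLIR APIs, the surrounding function is illustrative only.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult decomposeConvolutions(Operation *op) {
  // Registers DownscaleSizeOneWindowed2DConvolution and
  // DownscaleDepthwiseConv2DNhwcHwcOp with the default benefit.
  RewritePatternSet patterns(op->getContext());
  linalg::populateDecomposeConvolutionPatterns(patterns);
  // Apply greedily to a fixed point; convolutions whose window dimensions
  // are not of size 1 simply fail to match and are left untouched.
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```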

mlir/test/Dialect/Linalg/decompose-convolution.mlir

Lines changed: 24 additions & 0 deletions
@@ -68,3 +68,27 @@ func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x
          outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
   return %0 : tensor<4x1x2x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>, %out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32> {
+  // CHECK: linalg.depthwise_conv_1d_nwc_wc
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+         outs(%out : tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+  return %0 : tensor<1x1x56x96xf32>
+}
+
+// -----
+
+// Do not convert convolution ops whose window dimensions are not ones.
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>, %out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> {
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
+         outs(%out : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+  return %0 : tensor<1x56x56x96xf32>
+}
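Roughly, the first (positive) test case above rewrites into IR of the following shape. This is a hand-written sketch, not the commit's CHECK output: SSA names and the exact slice syntax are illustrative, but the shapes follow the pattern's logic (drop the size-1 h/kh/oh dimensions, rank-reduce strides and dilations from `vector<2xi64>` to `vector<1xi64>`, then insert the 1-D result back).

```mlir
// Illustrative sketch of the decomposed form; not taken from the commit.
%in = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 113, 96] [1, 1, 1, 1]
        : tensor<1x1x113x96xf32> to tensor<1x113x96xf32>
%ker = tensor.extract_slice %filter[0, 0, 0] [1, 3, 96] [1, 1, 1]
        : tensor<1x3x96xf32> to tensor<3x96xf32>
%init = tensor.extract_slice %out[0, 0, 0, 0] [1, 1, 56, 96] [1, 1, 1, 1]
        : tensor<1x1x56x96xf32> to tensor<1x56x96xf32>
// 1-D depthwise convolution on the rank-reduced operands.
%conv = linalg.depthwise_conv_1d_nwc_wc
          {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
          ins(%in, %ker : tensor<1x113x96xf32>, tensor<3x96xf32>)
          outs(%init : tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
// Re-insert the 1-D result into the original 2-D output tensor.
%res = tensor.insert_slice %conv into %out[0, 0, 0, 0] [1, 1, 56, 96] [1, 1, 1, 1]
        : tensor<1x56x96xf32> into tensor<1x1x56x96xf32>
```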
