@@ -904,10 +904,83 @@ struct DownscaleSizeOneWindowed2DConvolution final
 };
 };

+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+    if (linalgOp.hasBufferSemantics())
+      return failure(); // To be implemented.
+
+    Value input = convOp.inputs().front();
+    Value kernel = convOp.inputs().back();
+    Value output = convOp.outputs().front();
+
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+    auto kernelShape = kernelType.getShape();
+    auto outputShape = outputType.getShape();
+
+    // Only handle the case where at least one of the window dimensions is
+    // of size 1. Other cases can rely on tiling to reduce to such cases.
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+    int64_t ohSize = outputShape[1], owSize = outputShape[2];
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
+    if (!removeH && !removeW)
+      return failure();
+
+    // Get new shapes and types for all operands by removing the size-1
+    // dimension.
+    using RTTBuilder = RankedTensorType::Builder;
+    auto newInputType = RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
+    auto newKernelType = RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
+    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+    // Rank-reduce operands.
+    Location loc = convOp.getLoc();
+    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, input, newInputType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
+    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, output, newOutputType);
+
+    // Rank-reduce strides and dilations too.
+    // TODO: dropDim 1-liner helper.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    strides.erase(strides.begin() + (removeH ? 0 : 1));
+    auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+    auto dilations =
+        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+        loc, newOutputType, ValueRange{newInput, newKernel},
+        ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+    // Insert back.
+    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, conv1DOp.getResult(0), output);
+    rewriter.replaceOp(convOp, inserted);
+
+    return success();
+  };
+};
+
 } // namespace

 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
-                                                      benefit);
+  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+                                                  benefit);
 }
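
For reference, a hypothetical before/after of what this pattern does to the IR. The op names follow the linalg named-op convention for `DepthwiseConv2DNhwcHwcOp`/`DepthwiseConv1DNwcWcOp`; the shapes and attribute spellings are illustrative assumptions, not taken from this commit:

```mlir
// Before: kh == 1 and oh == 1, so the H window dimension can be dropped.
%res = linalg.depthwise_conv_2d_nhwc_hwc
    {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>}
    ins(%input, %kernel : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
    outs(%init : tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>

// After: operands are rank-reduced with tensor.extract_slice, the conv
// becomes 1-D with the H stride/dilation entries dropped, and the result
// is re-expanded with tensor.insert_slice.
%in  = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 113, 96] [1, 1, 1, 1]
    : tensor<1x1x113x96xf32> to tensor<1x113x96xf32>
%ker = tensor.extract_slice %kernel[0, 0, 0] [1, 3, 96] [1, 1, 1]
    : tensor<1x3x96xf32> to tensor<3x96xf32>
%out = tensor.extract_slice %init[0, 0, 0, 0] [1, 1, 56, 96] [1, 1, 1, 1]
    : tensor<1x1x56x96xf32> to tensor<1x56x96xf32>
%conv = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
    ins(%in, %ker : tensor<1x113x96xf32>, tensor<3x96xf32>)
    outs(%out : tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
%res = tensor.insert_slice %conv into %init[0, 0, 0, 0] [1, 1, 56, 96] [1, 1, 1, 1]
    : tensor<1x56x96xf32> into tensor<1x1x56x96xf32>
```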
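
And a minimal sketch of driving these patterns inside a pass, assuming a `funcOp` and the usual greedy rewrite driver (none of this is part of the diff above):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Inside some runOnOperation(); `funcOp` is assumed context, not from this commit.
RewritePatternSet patterns(funcOp.getContext());
linalg::populateDecomposeConvolutionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();
```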