@@ -840,3 +840,98 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
840
840
rewriter.replaceOp (sliceOp, tiledPadOp->getResults ());
841
841
return success ();
842
842
}
843
+
844
+ namespace {
845
+ // The following are patterns for downscaling convolution ops with size-1
846
+ // window dimensions.
847
+ //
848
+ // Note that we'd eventually want to write such transformations in a generic
849
+ // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
850
+ // and then turning back to named ops. But for now it's fine to have a few
851
+ // patterns matching special ops to get started.
852
+
853
+ // / Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
854
+ // / convolution ops.
855
+ struct DownscaleSizeOneWindowed2DConvolution final
856
+ : public OpRewritePattern<Conv2DNhwcHwcfOp> {
857
+ using OpRewritePattern::OpRewritePattern;
858
+
859
+ LogicalResult matchAndRewrite (linalg::Conv2DNhwcHwcfOp convOp,
860
+ PatternRewriter &rewriter) const override {
861
+ auto linalgOp = cast<linalg::LinalgOp>(*convOp);
862
+ if (linalgOp.hasBufferSemantics ())
863
+ return failure (); // To be implemented
864
+
865
+ Value input = convOp.inputs ().front ();
866
+ Value filter = convOp.inputs ().back ();
867
+ Value output = convOp.outputs ().front ();
868
+
869
+ auto inputType = input.getType ().dyn_cast <RankedTensorType>();
870
+ auto filterType = filter.getType ().dyn_cast <RankedTensorType>();
871
+ auto outputType = output.getType ().dyn_cast <RankedTensorType>();
872
+
873
+ auto inputShape = inputType.getShape ();
874
+ auto filterShape = filterType.getShape ();
875
+ auto outputShape = outputType.getShape ();
876
+
877
+ // Only handle the case where at least one of the window dimensions is
878
+ // of size 1. Other cases can rely on tiling to reduce to such cases.
879
+ int64_t fhSize = filterShape[0 ], fwSize = filterShape[1 ];
880
+ int64_t ohSize = outputShape[1 ], owSize = outputShape[2 ];
881
+ if (!(fhSize == 1 && ohSize == 1 ) && !(fwSize == 1 && owSize == 1 ))
882
+ return failure ();
883
+ bool removeH = ohSize == 1 ;
884
+
885
+ // Get new shapes and types for all operands by removing the size-1
886
+ // dimension.
887
+
888
+ SmallVector<int64_t , 3 > newInputShape{
889
+ inputShape[0 ], inputShape[removeH ? 2 : 1 ], inputShape[3 ]};
890
+ auto newInputType = RankedTensorType::get (
891
+ newInputShape, inputType.getElementType (), inputType.getEncoding ());
892
+
893
+ SmallVector<int64_t , 3 > newFilterShape{filterShape[removeH ? 1 : 0 ],
894
+ filterShape[2 ], filterShape[3 ]};
895
+ auto newFilterType = RankedTensorType::get (
896
+ newFilterShape, filterType.getElementType (), filterType.getEncoding ());
897
+
898
+ SmallVector<int64_t , 3 > newOutputShape{
899
+ outputShape[0 ], outputShape[removeH ? 2 : 1 ], outputShape[3 ]};
900
+ auto newOutputType = RankedTensorType::get (
901
+ newOutputShape, outputType.getElementType (), outputType.getEncoding ());
902
+
903
+ SmallVector<ReassociationIndices, 3 > ioReshapeIndices = {{0 }, {1 , 2 }, {3 }};
904
+ SmallVector<ReassociationIndices, 3 > fReshapeIndices = {{0 , 1 }, {2 }, {3 }};
905
+
906
+ // Reshape all operands for 1-D convolution.
907
+ Location loc = convOp.getLoc ();
908
+ Value newInput = rewriter.create <linalg::TensorCollapseShapeOp>(
909
+ loc, newInputType, input, ioReshapeIndices);
910
+ Value newFilter = rewriter.create <linalg::TensorCollapseShapeOp>(
911
+ loc, newFilterType, filter, fReshapeIndices );
912
+ Value newOutput = rewriter.create <linalg::TensorCollapseShapeOp>(
913
+ loc, newOutputType, output, ioReshapeIndices);
914
+
915
+ // We need to shrink the strides and dilations too.
916
+ auto stride = convOp.strides ().getFlatValue <int64_t >(removeH ? 1 : 0 );
917
+ auto stridesAttr = rewriter.getI64VectorAttr (stride);
918
+ auto dilation = convOp.dilations ().getFlatValue <int64_t >(removeH ? 1 : 0 );
919
+ auto dilationsAttr = rewriter.getI64VectorAttr (dilation);
920
+
921
+ auto conv1DOp = rewriter.create <linalg::Conv1DNwcWcfOp>(
922
+ loc, newOutputType, ValueRange{newInput, newFilter},
923
+ ValueRange{newOutput}, stridesAttr, dilationsAttr);
924
+
925
+ rewriter.replaceOpWithNewOp <linalg::TensorExpandShapeOp>(
926
+ convOp, outputType, conv1DOp.getResult (0 ), ioReshapeIndices);
927
+ return success ();
928
+ };
929
+ };
930
+
931
+ } // namespace
932
+
933
+ void linalg::populateDecomposeConvolutionPatterns (RewritePatternSet &patterns,
934
+ PatternBenefit benefit) {
935
+ patterns.add <DownscaleSizeOneWindowed2DConvolution>(patterns.getContext (),
936
+ benefit);
937
+ }
0 commit comments