@@ -2739,6 +2739,122 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2739
2739
return SmallVector<Value>{result};
2740
2740
}
2741
2741
2742
+ // ===----------------------------------------------------------------------===//
2743
+ // WinogradFilterTransformOp
2744
+ // ===----------------------------------------------------------------------===//
2745
+
2746
+ LogicalResult WinogradFilterTransformOp::verify () {
2747
+ auto filterType = cast<ShapedType>(getFilter ().getType ());
2748
+ ArrayRef<int64_t > filterShape = filterType.getShape ();
2749
+ int64_t filterH = filterShape[1 ];
2750
+ int64_t filterW = filterShape[2 ];
2751
+ int64_t r = getR ();
2752
+ int64_t m = getM ();
2753
+
2754
+ if (filterH != r && filterH != 1 )
2755
+ return emitOpError (" expect filter height either equals to r or 1" );
2756
+ if (filterW != r && filterW != 1 )
2757
+ return emitOpError (" expect filter width either equals to r or 1" );
2758
+ if (filterH == 1 && filterW == 1 )
2759
+ return emitOpError (" expect either filter height or width equals to r" );
2760
+
2761
+ SmallVector<int64_t > expectedOutputShape;
2762
+ expectedOutputShape.push_back (filterH == r ? m + r - 1 : 1 );
2763
+ expectedOutputShape.push_back (filterW == r ? m + r - 1 : 1 );
2764
+ expectedOutputShape.push_back (filterShape[3 ]);
2765
+ expectedOutputShape.push_back (filterShape[0 ]);
2766
+
2767
+ auto outputType = cast<ShapedType>(getOutput ().getType ());
2768
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
2769
+ if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
2770
+ return emitOpError (" the output shape is not expected" );
2771
+ }
2772
+ return success ();
2773
+ }
2774
+
2775
+ // ===----------------------------------------------------------------------===//
2776
+ // WinogradInputTransformOp
2777
+ // ===----------------------------------------------------------------------===//
2778
+
2779
+ LogicalResult WinogradInputTransformOp::verify () {
2780
+ auto inputType = cast<ShapedType>(getInput ().getType ());
2781
+ ArrayRef<int64_t > inputShape = inputType.getShape ();
2782
+ int64_t inputH = inputShape[1 ];
2783
+ int64_t inputW = inputShape[2 ];
2784
+ int m = getM ();
2785
+ int r = getR ();
2786
+ int64_t tileSize = m + r - 1 ;
2787
+ bool leftTransform = inputH != 1 ;
2788
+ bool rightTransform = inputW != 1 ;
2789
+
2790
+ SmallVector<int64_t > expectedOutputShape (6 , inputH);
2791
+ if (ShapedType::isDynamic (inputH)) {
2792
+ expectedOutputShape[0 ] = tileSize;
2793
+ expectedOutputShape[2 ] = ShapedType::kDynamic ;
2794
+ } else {
2795
+ expectedOutputShape[0 ] = leftTransform ? tileSize : 1 ;
2796
+ expectedOutputShape[2 ] = leftTransform ? (inputH - (r - 1 )) / m : 1 ;
2797
+ }
2798
+ if (ShapedType::isDynamic (inputW)) {
2799
+ expectedOutputShape[1 ] = tileSize;
2800
+ expectedOutputShape[3 ] = ShapedType::kDynamic ;
2801
+ } else {
2802
+ expectedOutputShape[1 ] = rightTransform ? tileSize : 1 ;
2803
+ expectedOutputShape[3 ] = rightTransform ? (inputW - (r - 1 )) / m : 1 ;
2804
+ }
2805
+ expectedOutputShape[4 ] = inputShape[0 ];
2806
+ expectedOutputShape[5 ] = inputShape[3 ];
2807
+
2808
+ auto outputType = cast<ShapedType>(getOutput ().getType ());
2809
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
2810
+ if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
2811
+ return emitOpError (" the output shape is not expected" );
2812
+ }
2813
+ return success ();
2814
+ }
2815
+
2816
+ // ===----------------------------------------------------------------------===//
2817
+ // WinogradOutputTransformOp
2818
+ // ===----------------------------------------------------------------------===//
2819
+
2820
+ LogicalResult WinogradOutputTransformOp::verify () {
2821
+ auto valueType = cast<ShapedType>(getValue ().getType ());
2822
+ ArrayRef<int64_t > valueShape = valueType.getShape ();
2823
+ int64_t valueH = valueShape[0 ];
2824
+ int64_t valueW = valueShape[1 ];
2825
+ int64_t valueTileH = valueShape[2 ];
2826
+ int64_t valueTileW = valueShape[3 ];
2827
+ int m = getM ();
2828
+ int r = getR ();
2829
+ bool leftTransform = valueH != 1 ;
2830
+ bool rightTransform = valueW != 1 ;
2831
+
2832
+ SmallVector<int64_t > expectedOutputShape (4 , valueH);
2833
+ if (ShapedType::isDynamic (valueH) || ShapedType::isDynamic (valueTileH)) {
2834
+ expectedOutputShape[1 ] = ShapedType::kDynamic ;
2835
+ } else {
2836
+ if (valueH != (leftTransform ? m + r - 1 : 1 ))
2837
+ return emitOpError (" expect input height equals to input tile size" );
2838
+ expectedOutputShape[1 ] = (leftTransform ? m : 1 ) * valueTileH;
2839
+ }
2840
+ if (ShapedType::isDynamic (valueW) || ShapedType::isDynamic (valueTileW)) {
2841
+ expectedOutputShape[2 ] = ShapedType::kDynamic ;
2842
+ } else {
2843
+ if (valueW != (rightTransform ? m + r - 1 : 1 ))
2844
+ return emitOpError (" expect input width equals to input tile size" );
2845
+ expectedOutputShape[2 ] = (rightTransform ? m : 1 ) * valueTileW;
2846
+ }
2847
+ expectedOutputShape[0 ] = valueShape[4 ];
2848
+ expectedOutputShape[3 ] = valueShape[5 ];
2849
+
2850
+ auto outputType = cast<ShapedType>(getOutput ().getType ());
2851
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
2852
+ if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
2853
+ return emitOpError (" the output shape is not expected" );
2854
+ }
2855
+ return success ();
2856
+ }
2857
+
2742
2858
// ===----------------------------------------------------------------------===//
2743
2859
// LinalgDialect
2744
2860
// ===----------------------------------------------------------------------===//
0 commit comments