@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
2776
2776
// WinogradInputTransformOp
2777
2777
// ===----------------------------------------------------------------------===//
2778
2778
2779
+ Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
2780
+ Location loc) {
2781
+ if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2782
+ auto intAttr = cast<IntegerAttr>(attr);
2783
+ return builder.create <arith::ConstantOp>(loc, intAttr);
2784
+ }
2785
+ return opFoldResult.get <Value>();
2786
+ }
2787
+
2779
2788
LogicalResult WinogradInputTransformOp::verify () {
2780
2789
auto inputType = cast<ShapedType>(getInput ().getType ());
2781
2790
ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2813,6 +2822,113 @@ LogicalResult WinogradInputTransformOp::verify() {
2813
2822
return success ();
2814
2823
}
2815
2824
2825
+ SmallVector<Range>
2826
+ WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
2827
+ Location loc = getLoc ();
2828
+ auto indexType = builder.getIndexType ();
2829
+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2830
+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2831
+ Value output = getOutput ();
2832
+ SmallVector<Range> loopBounds (6 );
2833
+ for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2834
+ loopBounds[dim].offset = zeroAttr;
2835
+ loopBounds[dim].size = getDimValue (builder, loc, output, dim);
2836
+ loopBounds[dim].stride = oneAttr;
2837
+ }
2838
+ return loopBounds;
2839
+ }
2840
+
2841
+ SmallVector<utils::IteratorType>
2842
+ WinogradInputTransformOp::getLoopIteratorTypes () {
2843
+ SmallVector<utils::IteratorType> iteratorTypes (6 ,
2844
+ utils::IteratorType::parallel);
2845
+ return iteratorTypes;
2846
+ }
2847
+
2848
+ LogicalResult WinogradInputTransformOp::getResultTilePosition (
2849
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2850
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2851
+ SmallVector<OpFoldResult> &resultSizes) {
2852
+ auto zeroAttr = builder.getI64IntegerAttr (0 );
2853
+ auto oneAttr = builder.getI64IntegerAttr (1 );
2854
+
2855
+ resultOffsets.push_back (zeroAttr);
2856
+ resultOffsets.push_back (zeroAttr);
2857
+ resultOffsets.push_back (offsets[2 ]);
2858
+ resultOffsets.push_back (offsets[3 ]);
2859
+ resultOffsets.push_back (zeroAttr);
2860
+ resultOffsets.push_back (zeroAttr);
2861
+ resultSizes.push_back (sizes[0 ]);
2862
+ resultSizes.push_back (sizes[1 ]);
2863
+ resultSizes.push_back (oneAttr);
2864
+ resultSizes.push_back (oneAttr);
2865
+ resultSizes.push_back (sizes[4 ]);
2866
+ resultSizes.push_back (sizes[5 ]);
2867
+
2868
+ return success ();
2869
+ }
2870
+
2871
+ FailureOr<TilingResult>
2872
+ WinogradInputTransformOp::getTiledImplementation (OpBuilder &builder,
2873
+ ArrayRef<OpFoldResult> offsets,
2874
+ ArrayRef<OpFoldResult> sizes) {
2875
+ auto oneAttr = builder.getI64IntegerAttr (1 );
2876
+ auto zeroAttr = builder.getI64IntegerAttr (0 );
2877
+ Value input = getInput ();
2878
+ auto inputType = cast<ShapedType>(input.getType ());
2879
+ auto inputShape = inputType.getShape ();
2880
+ int64_t inputH = inputShape[1 ];
2881
+ int64_t inputW = inputShape[2 ];
2882
+ int64_t m = getM ();
2883
+ int64_t r = getR ();
2884
+ int64_t alpha = m + r - 1 ;
2885
+ int64_t alphaH = inputH != 1 ? alpha : 1 ;
2886
+ int64_t alphaW = inputW != 1 ? alpha : 1 ;
2887
+ auto alphaHAttr = builder.getI64IntegerAttr (alphaH);
2888
+ auto alphaWAttr = builder.getI64IntegerAttr (alphaW);
2889
+
2890
+ Location loc = getLoc ();
2891
+ SmallVector<Value> tiledOperands;
2892
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2893
+
2894
+ auto context = builder.getContext ();
2895
+ auto affineMap =
2896
+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
2897
+ Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
2898
+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2899
+ Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
2900
+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
2901
+
2902
+ sliceOffsets.push_back (zeroAttr);
2903
+ sliceOffsets.push_back (mappedOffset1);
2904
+ sliceOffsets.push_back (mappedOffset2);
2905
+ sliceOffsets.push_back (zeroAttr);
2906
+ sliceSizes.push_back (sizes[4 ]);
2907
+ sliceSizes.push_back (alphaHAttr);
2908
+ sliceSizes.push_back (alphaWAttr);
2909
+ sliceSizes.push_back (sizes[5 ]);
2910
+ SmallVector<OpFoldResult> inputStrides (4 , oneAttr);
2911
+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2912
+ loc, getInput (), sliceOffsets, sliceSizes, inputStrides));
2913
+
2914
+ sliceOffsets.clear ();
2915
+ sliceSizes.clear ();
2916
+ if (failed (getResultTilePosition (builder, 1 , offsets, sizes, sliceOffsets,
2917
+ sliceSizes)))
2918
+ return failure ();
2919
+
2920
+ SmallVector<OpFoldResult> outputStrides (6 , oneAttr);
2921
+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2922
+ loc, getOutput (), sliceOffsets, sliceSizes, outputStrides));
2923
+
2924
+ SmallVector<Type, 4 > resultTypes;
2925
+ resultTypes.push_back (tiledOperands[1 ].getType ());
2926
+ Operation *tiledOp =
2927
+ mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
2928
+
2929
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
2930
+ }
2931
+
2816
2932
// ===----------------------------------------------------------------------===//
2817
2933
// WinogradOutputTransformOp
2818
2934
// ===----------------------------------------------------------------------===//
@@ -2855,6 +2971,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
2855
2971
return success ();
2856
2972
}
2857
2973
2974
+ SmallVector<Range>
2975
+ WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
2976
+ Location loc = getLoc ();
2977
+ auto indexType = builder.getIndexType ();
2978
+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2979
+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2980
+ Value value = getValue ();
2981
+ SmallVector<Range> loopBounds (6 );
2982
+ for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2983
+ loopBounds[dim].offset = zeroAttr;
2984
+ loopBounds[dim].size = getDimValue (builder, loc, value, dim);
2985
+ loopBounds[dim].stride = oneAttr;
2986
+ }
2987
+ return loopBounds;
2988
+ }
2989
+
2990
+ SmallVector<utils::IteratorType>
2991
+ WinogradOutputTransformOp::getLoopIteratorTypes () {
2992
+ SmallVector<utils::IteratorType> iteratorTypes (6 ,
2993
+ utils::IteratorType::parallel);
2994
+ return iteratorTypes;
2995
+ }
2996
+
2997
+ LogicalResult WinogradOutputTransformOp::getResultTilePosition (
2998
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2999
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3000
+ SmallVector<OpFoldResult> &resultSizes) {
3001
+ auto zeroAttr = builder.getI64IntegerAttr (0 );
3002
+ Value output = getOutput ();
3003
+ auto outputType = cast<ShapedType>(output.getType ());
3004
+ auto outputShape = outputType.getShape ();
3005
+ int64_t outputH = outputShape[1 ];
3006
+ int64_t outputW = outputShape[2 ];
3007
+ int64_t m = getM ();
3008
+ auto heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
3009
+ auto widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
3010
+
3011
+ Location loc = getLoc ();
3012
+ auto context = builder.getContext ();
3013
+ auto affineMap =
3014
+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
3015
+ Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3016
+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
3017
+ Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3018
+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3019
+
3020
+ resultOffsets.push_back (zeroAttr);
3021
+ resultOffsets.push_back (mappedOffset1);
3022
+ resultOffsets.push_back (mappedOffset2);
3023
+ resultOffsets.push_back (zeroAttr);
3024
+ resultSizes.push_back (sizes[4 ]);
3025
+ resultSizes.push_back (heightM);
3026
+ resultSizes.push_back (widthM);
3027
+ resultSizes.push_back (sizes[5 ]);
3028
+ return success ();
3029
+ }
3030
+
3031
+ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation (
3032
+ OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3033
+ ArrayRef<OpFoldResult> sizes) {
3034
+ auto oneAttr = builder.getI64IntegerAttr (1 );
3035
+ auto zeroAttr = builder.getI64IntegerAttr (0 );
3036
+ Location loc = getLoc ();
3037
+ SmallVector<Value> tiledOperands;
3038
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3039
+
3040
+ sliceOffsets.push_back (zeroAttr);
3041
+ sliceOffsets.push_back (zeroAttr);
3042
+ sliceOffsets.push_back (offsets[2 ]);
3043
+ sliceOffsets.push_back (offsets[3 ]);
3044
+ sliceOffsets.push_back (zeroAttr);
3045
+ sliceOffsets.push_back (zeroAttr);
3046
+ sliceSizes.push_back (sizes[0 ]);
3047
+ sliceSizes.push_back (sizes[1 ]);
3048
+ sliceSizes.push_back (oneAttr);
3049
+ sliceSizes.push_back (oneAttr);
3050
+ sliceSizes.push_back (sizes[4 ]);
3051
+ sliceSizes.push_back (sizes[5 ]);
3052
+ SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
3053
+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3054
+ loc, getValue (), sliceOffsets, sliceSizes, sliceStrides));
3055
+
3056
+ sliceOffsets.clear ();
3057
+ sliceSizes.clear ();
3058
+ if (failed (getResultTilePosition (builder, 1 , offsets, sizes, sliceOffsets,
3059
+ sliceSizes)))
3060
+ return failure ();
3061
+
3062
+ SmallVector<OpFoldResult> strides (4 , oneAttr);
3063
+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3064
+ loc, getOutput (), sliceOffsets, sliceSizes, strides));
3065
+
3066
+ SmallVector<Type, 4 > resultTypes;
3067
+ resultTypes.push_back (tiledOperands[1 ].getType ());
3068
+ Operation *tiledOp =
3069
+ mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
3070
+
3071
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
3072
+ }
3073
+
2858
3074
// ===----------------------------------------------------------------------===//
2859
3075
// LinalgDialect
2860
3076
// ===----------------------------------------------------------------------===//
0 commit comments