@@ -2755,93 +2755,19 @@ LogicalResult WinogradFilterTransformOp::verify() {
2755
2755
return success ();
2756
2756
}
2757
2757
2758
- SmallVector<Range>
2759
- WinogradFilterTransformOp::getIterationDomain (OpBuilder &builder) {
2760
- Location loc = getLoc ();
2761
- Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
2762
- Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
2763
- Value output = getOutput ();
2764
- SmallVector<Range> loopBounds (6 );
2765
- for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2766
- loopBounds[dim].offset = zero;
2767
- loopBounds[dim].size = getDimValue (builder, loc, output, dim);
2768
- loopBounds[dim].stride = one;
2769
- }
2770
- return loopBounds;
2771
- }
2772
-
2773
- SmallVector<utils::IteratorType>
2774
- WinogradFilterTransformOp::getLoopIteratorTypes () {
2775
- SmallVector<utils::IteratorType> iteratorTypes (6 ,
2776
- utils::IteratorType::parallel);
2777
- return iteratorTypes;
2778
- }
2758
+ // ===----------------------------------------------------------------------===//
2759
+ // WinogradInputTransformOp
2760
+ // ===----------------------------------------------------------------------===//
2779
2761
2780
2762
Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
2781
2763
Location loc) {
2782
- if (auto val = opFoldResult.dyn_cast <Value>()) {
2783
- return val;
2784
- } else if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2764
+ if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2785
2765
auto intAttr = cast<IntegerAttr>(attr);
2786
2766
return builder.create <arith::ConstantOp>(loc, intAttr);
2787
2767
}
2788
- // This should never happen if OpFoldResult is correctly formed.
2789
- return nullptr ;
2768
+ return opFoldResult.get <Value>();
2790
2769
}
2791
2770
2792
- LogicalResult WinogradFilterTransformOp::getResultTilePosition (
2793
- OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2794
- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2795
- SmallVector<OpFoldResult> &resultSizes) {
2796
- auto zeroAttr = builder.getI64IntegerAttr (0 );
2797
- auto oneAttr = builder.getI64IntegerAttr (1 );
2798
-
2799
- resultOffsets.push_back (offsets[0 ]);
2800
- resultOffsets.push_back (offsets[1 ]);
2801
- resultOffsets.push_back (zeroAttr);
2802
- resultOffsets.push_back (zeroAttr);
2803
- resultOffsets.push_back (zeroAttr);
2804
- resultOffsets.push_back (zeroAttr);
2805
- resultSizes.push_back (oneAttr);
2806
- resultSizes.push_back (oneAttr);
2807
- resultSizes.push_back (sizes[2 ]);
2808
- resultSizes.push_back (sizes[3 ]);
2809
- resultSizes.push_back (sizes[4 ]);
2810
- resultSizes.push_back (sizes[5 ]);
2811
-
2812
- return success ();
2813
- }
2814
-
2815
- FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation (
2816
- OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
2817
- ArrayRef<OpFoldResult> sizes) {
2818
- auto oneAttr = builder.getI64IntegerAttr (1 );
2819
-
2820
- Location loc = getLoc ();
2821
- SmallVector<OpFoldResult> strides (6 , oneAttr);
2822
- SmallVector<Value> tiledOperands;
2823
- tiledOperands.emplace_back (getFilter ());
2824
-
2825
- SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2826
- if (failed (getResultTilePosition (builder, 1 , offsets, sizes, sliceOffsets,
2827
- sliceSizes)))
2828
- return failure ();
2829
-
2830
- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2831
- loc, getOutput (), sliceOffsets, sliceSizes, strides));
2832
-
2833
- SmallVector<Type, 4 > resultTypes;
2834
- resultTypes.push_back (tiledOperands[1 ].getType ());
2835
- Operation *tiledOp =
2836
- mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
2837
-
2838
- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
2839
- }
2840
-
2841
- // ===----------------------------------------------------------------------===//
2842
- // WinogradInputTransformOp
2843
- // ===----------------------------------------------------------------------===//
2844
-
2845
2771
LogicalResult WinogradInputTransformOp::verify () {
2846
2772
auto inputType = cast<ShapedType>(getInput ().getType ());
2847
2773
ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2887,14 +2813,15 @@ LogicalResult WinogradInputTransformOp::verify() {
2887
2813
SmallVector<Range>
2888
2814
WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
2889
2815
Location loc = getLoc ();
2890
- Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
2891
- Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
2816
+ auto indexType = builder.getIndexType ();
2817
+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2818
+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2892
2819
Value output = getOutput ();
2893
2820
SmallVector<Range> loopBounds (6 );
2894
2821
for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2895
- loopBounds[dim].offset = zero ;
2822
+ loopBounds[dim].offset = zeroAttr ;
2896
2823
loopBounds[dim].size = getDimValue (builder, loc, output, dim);
2897
- loopBounds[dim].stride = one ;
2824
+ loopBounds[dim].stride = oneAttr ;
2898
2825
}
2899
2826
return loopBounds;
2900
2827
}
@@ -2913,16 +2840,16 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
2913
2840
auto zeroAttr = builder.getI64IntegerAttr (0 );
2914
2841
auto oneAttr = builder.getI64IntegerAttr (1 );
2915
2842
2916
- resultOffsets.push_back (offsets[0 ]);
2917
- resultOffsets.push_back (offsets[1 ]);
2918
2843
resultOffsets.push_back (zeroAttr);
2919
2844
resultOffsets.push_back (zeroAttr);
2845
+ resultOffsets.push_back (offsets[2 ]);
2846
+ resultOffsets.push_back (offsets[3 ]);
2920
2847
resultOffsets.push_back (zeroAttr);
2921
2848
resultOffsets.push_back (zeroAttr);
2849
+ resultSizes.push_back (sizes[0 ]);
2850
+ resultSizes.push_back (sizes[1 ]);
2922
2851
resultSizes.push_back (oneAttr);
2923
2852
resultSizes.push_back (oneAttr);
2924
- resultSizes.push_back (sizes[2 ]);
2925
- resultSizes.push_back (sizes[3 ]);
2926
2853
resultSizes.push_back (sizes[4 ]);
2927
2854
resultSizes.push_back (sizes[5 ]);
2928
2855
@@ -2956,9 +2883,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
2956
2883
auto affineMap =
2957
2884
AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
2958
2885
Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
2959
- loc, affineMap, getValueFromOpFoldResult (offsets[0 ], builder, loc));
2886
+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2960
2887
Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
2961
- loc, affineMap, getValueFromOpFoldResult (offsets[1 ], builder, loc));
2888
+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
2962
2889
2963
2890
sliceOffsets.push_back (zeroAttr);
2964
2891
sliceOffsets.push_back (mappedOffset1);
@@ -3033,14 +2960,15 @@ LogicalResult WinogradOutputTransformOp::verify() {
3033
2960
SmallVector<Range>
3034
2961
WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
3035
2962
Location loc = getLoc ();
3036
- Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
3037
- Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
2963
+ auto indexType = builder.getIndexType ();
2964
+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2965
+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
3038
2966
Value value = getValue ();
3039
2967
SmallVector<Range> loopBounds (6 );
3040
2968
for (unsigned dim = 0 ; dim < 6 ; ++dim) {
3041
- loopBounds[dim].offset = zero ;
2969
+ loopBounds[dim].offset = zeroAttr ;
3042
2970
loopBounds[dim].size = getDimValue (builder, loc, value, dim);
3043
- loopBounds[dim].stride = one ;
2971
+ loopBounds[dim].stride = oneAttr ;
3044
2972
}
3045
2973
return loopBounds;
3046
2974
}
@@ -3071,9 +2999,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3071
2999
auto affineMap =
3072
3000
AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
3073
3001
Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3074
- loc, affineMap, getValueFromOpFoldResult (offsets[0 ], builder, loc));
3002
+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
3075
3003
Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3076
- loc, affineMap, getValueFromOpFoldResult (offsets[1 ], builder, loc));
3004
+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3077
3005
3078
3006
resultOffsets.push_back (zeroAttr);
3079
3007
resultOffsets.push_back (mappedOffset1);
@@ -3095,16 +3023,16 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3095
3023
SmallVector<Value> tiledOperands;
3096
3024
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3097
3025
3098
- sliceOffsets.push_back (offsets[0 ]);
3099
- sliceOffsets.push_back (offsets[1 ]);
3100
3026
sliceOffsets.push_back (zeroAttr);
3101
3027
sliceOffsets.push_back (zeroAttr);
3028
+ sliceOffsets.push_back (offsets[2 ]);
3029
+ sliceOffsets.push_back (offsets[3 ]);
3102
3030
sliceOffsets.push_back (zeroAttr);
3103
3031
sliceOffsets.push_back (zeroAttr);
3032
+ sliceSizes.push_back (sizes[0 ]);
3033
+ sliceSizes.push_back (sizes[1 ]);
3104
3034
sliceSizes.push_back (oneAttr);
3105
3035
sliceSizes.push_back (oneAttr);
3106
- sliceSizes.push_back (sizes[2 ]);
3107
- sliceSizes.push_back (sizes[3 ]);
3108
3036
sliceSizes.push_back (sizes[4 ]);
3109
3037
sliceSizes.push_back (sizes[5 ]);
3110
3038
SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
0 commit comments