@@ -2776,15 +2776,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
2776
2776
// WinogradInputTransformOp
2777
2777
// ===----------------------------------------------------------------------===//
2778
2778
2779
- Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
2780
- Location loc) {
2781
- if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2782
- auto intAttr = cast<IntegerAttr>(attr);
2783
- return builder.create <arith::ConstantOp>(loc, intAttr);
2784
- }
2785
- return opFoldResult.get <Value>();
2786
- }
2787
-
2788
2779
LogicalResult WinogradInputTransformOp::verify () {
2789
2780
auto inputType = cast<ShapedType>(getInput ().getType ());
2790
2781
ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2825,9 +2816,9 @@ LogicalResult WinogradInputTransformOp::verify() {
2825
2816
SmallVector<Range>
2826
2817
WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
2827
2818
Location loc = getLoc ();
2828
- auto indexType = builder.getIndexType ();
2829
- auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2830
- auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2819
+ IndexType indexType = builder.getIndexType ();
2820
+ IntegerAttr zeroAttr = builder.getIntegerAttr (indexType, 0 );
2821
+ IntegerAttr oneAttr = builder.getIntegerAttr (indexType, 1 );
2831
2822
Value output = getOutput ();
2832
2823
SmallVector<Range> loopBounds (6 );
2833
2824
for (unsigned dim = 0 ; dim < 6 ; ++dim) {
@@ -2849,21 +2840,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
2849
2840
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2850
2841
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2851
2842
SmallVector<OpFoldResult> &resultSizes) {
2852
- auto zeroAttr = builder.getI64IntegerAttr (0 );
2853
- auto oneAttr = builder.getI64IntegerAttr (1 );
2843
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
2844
+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
2854
2845
2855
- resultOffsets.push_back (zeroAttr);
2856
- resultOffsets.push_back (zeroAttr);
2857
- resultOffsets.push_back (offsets[2 ]);
2858
- resultOffsets.push_back (offsets[3 ]);
2859
- resultOffsets.push_back (zeroAttr);
2860
- resultOffsets.push_back (zeroAttr);
2861
- resultSizes.push_back (sizes[0 ]);
2862
- resultSizes.push_back (sizes[1 ]);
2863
- resultSizes.push_back (oneAttr);
2864
- resultSizes.push_back (oneAttr);
2865
- resultSizes.push_back (sizes[4 ]);
2866
- resultSizes.push_back (sizes[5 ]);
2846
+ resultOffsets.append (
2847
+ {zeroAttr, zeroAttr, offsets[2 ], offsets[3 ], zeroAttr, zeroAttr});
2848
+ resultSizes.append (
2849
+ {sizes[0 ], sizes[1 ], oneAttr, oneAttr, sizes[4 ], sizes[5 ]});
2867
2850
2868
2851
return success ();
2869
2852
}
@@ -2872,41 +2855,37 @@ FailureOr<TilingResult>
2872
2855
WinogradInputTransformOp::getTiledImplementation (OpBuilder &builder,
2873
2856
ArrayRef<OpFoldResult> offsets,
2874
2857
ArrayRef<OpFoldResult> sizes) {
2875
- auto oneAttr = builder.getI64IntegerAttr (1 );
2876
- auto zeroAttr = builder.getI64IntegerAttr (0 );
2858
+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
2859
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
2877
2860
Value input = getInput ();
2878
2861
auto inputType = cast<ShapedType>(input.getType ());
2879
- auto inputShape = inputType.getShape ();
2862
+ ArrayRef< int64_t > inputShape = inputType.getShape ();
2880
2863
int64_t inputH = inputShape[1 ];
2881
2864
int64_t inputW = inputShape[2 ];
2882
2865
int64_t m = getM ();
2883
2866
int64_t r = getR ();
2884
2867
int64_t alpha = m + r - 1 ;
2885
2868
int64_t alphaH = inputH != 1 ? alpha : 1 ;
2886
2869
int64_t alphaW = inputW != 1 ? alpha : 1 ;
2887
- auto alphaHAttr = builder.getI64IntegerAttr (alphaH);
2888
- auto alphaWAttr = builder.getI64IntegerAttr (alphaW);
2870
+ IntegerAttr alphaHAttr = builder.getI64IntegerAttr (alphaH);
2871
+ IntegerAttr alphaWAttr = builder.getI64IntegerAttr (alphaW);
2889
2872
2890
2873
Location loc = getLoc ();
2891
2874
SmallVector<Value> tiledOperands;
2892
2875
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2893
2876
2894
- auto context = builder.getContext ();
2877
+ MLIRContext * context = builder.getContext ();
2895
2878
auto affineMap =
2896
2879
AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
2897
2880
Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
2898
- loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2881
+ loc, affineMap,
2882
+ getValueOrCreateConstantIndexOp (builder, loc, offsets[2 ]));
2899
2883
Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
2900
- loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
2901
-
2902
- sliceOffsets.push_back (zeroAttr);
2903
- sliceOffsets.push_back (mappedOffset1);
2904
- sliceOffsets.push_back (mappedOffset2);
2905
- sliceOffsets.push_back (zeroAttr);
2906
- sliceSizes.push_back (sizes[4 ]);
2907
- sliceSizes.push_back (alphaHAttr);
2908
- sliceSizes.push_back (alphaWAttr);
2909
- sliceSizes.push_back (sizes[5 ]);
2884
+ loc, affineMap,
2885
+ getValueOrCreateConstantIndexOp (builder, loc, offsets[3 ]));
2886
+
2887
+ sliceOffsets.append ({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
2888
+ sliceSizes.append ({sizes[4 ], alphaHAttr, alphaWAttr, sizes[5 ]});
2910
2889
SmallVector<OpFoldResult> inputStrides (4 , oneAttr);
2911
2890
tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2912
2891
loc, getInput (), sliceOffsets, sliceSizes, inputStrides));
@@ -2921,7 +2900,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
2921
2900
tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2922
2901
loc, getOutput (), sliceOffsets, sliceSizes, outputStrides));
2923
2902
2924
- SmallVector<Type, 4 > resultTypes;
2903
+ SmallVector<Type> resultTypes;
2925
2904
resultTypes.push_back (tiledOperands[1 ].getType ());
2926
2905
Operation *tiledOp =
2927
2906
mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
@@ -2974,9 +2953,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
2974
2953
SmallVector<Range>
2975
2954
WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
2976
2955
Location loc = getLoc ();
2977
- auto indexType = builder.getIndexType ();
2978
- auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2979
- auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2956
+ IndexType indexType = builder.getIndexType ();
2957
+ IntegerAttr zeroAttr = builder.getIntegerAttr (indexType, 0 );
2958
+ IntegerAttr oneAttr = builder.getIntegerAttr (indexType, 1 );
2980
2959
Value value = getValue ();
2981
2960
SmallVector<Range> loopBounds (6 );
2982
2961
for (unsigned dim = 0 ; dim < 6 ; ++dim) {
@@ -2998,57 +2977,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
2998
2977
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2999
2978
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3000
2979
SmallVector<OpFoldResult> &resultSizes) {
3001
- auto zeroAttr = builder.getI64IntegerAttr (0 );
2980
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3002
2981
Value output = getOutput ();
3003
2982
auto outputType = cast<ShapedType>(output.getType ());
3004
- auto outputShape = outputType.getShape ();
2983
+ ArrayRef< int64_t > outputShape = outputType.getShape ();
3005
2984
int64_t outputH = outputShape[1 ];
3006
2985
int64_t outputW = outputShape[2 ];
3007
2986
int64_t m = getM ();
3008
- auto heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
3009
- auto widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
2987
+ IntegerAttr heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
2988
+ IntegerAttr widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
3010
2989
3011
2990
Location loc = getLoc ();
3012
- auto context = builder.getContext ();
2991
+ MLIRContext * context = builder.getContext ();
3013
2992
auto affineMap =
3014
2993
AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
3015
2994
Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3016
- loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2995
+ loc, affineMap,
2996
+ getValueOrCreateConstantIndexOp (builder, loc, offsets[2 ]));
3017
2997
Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3018
- loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3019
-
3020
- resultOffsets.push_back (zeroAttr);
3021
- resultOffsets.push_back (mappedOffset1);
3022
- resultOffsets.push_back (mappedOffset2);
3023
- resultOffsets.push_back (zeroAttr);
3024
- resultSizes.push_back (sizes[4 ]);
3025
- resultSizes.push_back (heightM);
3026
- resultSizes.push_back (widthM);
3027
- resultSizes.push_back (sizes[5 ]);
2998
+ loc, affineMap,
2999
+ getValueOrCreateConstantIndexOp (builder, loc, offsets[3 ]));
3000
+
3001
+ resultOffsets.append ({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
3002
+ resultSizes.append ({sizes[4 ], heightM, widthM, sizes[5 ]});
3028
3003
return success ();
3029
3004
}
3030
3005
3031
3006
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation (
3032
3007
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3033
3008
ArrayRef<OpFoldResult> sizes) {
3034
- auto oneAttr = builder.getI64IntegerAttr (1 );
3035
- auto zeroAttr = builder.getI64IntegerAttr (0 );
3009
+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3010
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3036
3011
Location loc = getLoc ();
3037
3012
SmallVector<Value> tiledOperands;
3038
3013
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3039
3014
3040
- sliceOffsets.push_back (zeroAttr);
3041
- sliceOffsets.push_back (zeroAttr);
3042
- sliceOffsets.push_back (offsets[2 ]);
3043
- sliceOffsets.push_back (offsets[3 ]);
3044
- sliceOffsets.push_back (zeroAttr);
3045
- sliceOffsets.push_back (zeroAttr);
3046
- sliceSizes.push_back (sizes[0 ]);
3047
- sliceSizes.push_back (sizes[1 ]);
3048
- sliceSizes.push_back (oneAttr);
3049
- sliceSizes.push_back (oneAttr);
3050
- sliceSizes.push_back (sizes[4 ]);
3051
- sliceSizes.push_back (sizes[5 ]);
3015
+ sliceOffsets.append (
3016
+ {zeroAttr, zeroAttr, offsets[2 ], offsets[3 ], zeroAttr, zeroAttr});
3017
+ sliceSizes.append ({sizes[0 ], sizes[1 ], oneAttr, oneAttr, sizes[4 ], sizes[5 ]});
3052
3018
SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
3053
3019
tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3054
3020
loc, getValue (), sliceOffsets, sliceSizes, sliceStrides));
@@ -3063,7 +3029,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3063
3029
tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3064
3030
loc, getOutput (), sliceOffsets, sliceSizes, strides));
3065
3031
3066
- SmallVector<Type, 4 > resultTypes;
3032
+ SmallVector<Type> resultTypes;
3067
3033
resultTypes.push_back (tiledOperands[1 ].getType ());
3068
3034
Operation *tiledOp =
3069
3035
mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
0 commit comments