26
26
#include " mlir/Transforms/DialectConversion.h"
27
27
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
28
28
29
+ #include " mlir/Interfaces/InferTypeOpInterface.h"
30
+
29
31
#include < numeric>
30
32
#include < type_traits>
31
33
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
34
36
35
37
static mlir::Value applyPad (Location loc, Value input, ArrayRef<int64_t > pad,
36
38
TypedAttr padAttr, OpBuilder &rewriter) {
37
- // Input should be padded if necessary.
39
+ // Input should be padded only if necessary.
38
40
if (llvm::all_of (pad, [](int64_t p) { return p == 0 ; }))
39
41
return input;
40
42
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
47
49
SmallVector<int64_t , 4 > paddedShape;
48
50
SmallVector<OpFoldResult, 8 > lowIndices;
49
51
SmallVector<OpFoldResult, 8 > highIndices;
50
- for (int i = 0 , s = inputShape.size (); i < s; i++ ) {
52
+ for (size_t i : llvm::seq ( inputShape.size ()) ) {
51
53
auto lowPad = pad[i * 2 ];
52
54
auto highPad = pad[i * 2 + 1 ];
53
55
if (ShapedType::isDynamic (inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
131
133
132
134
// Materializes the compile-time integer `attr` as an `arith::ConstantIndexOp`
// and returns its result. The op is created through `builder`, an
// ImplicitLocOpBuilder, so no explicit location argument is passed here.
static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}
138
138
139
139
// Calculating the output width/height using the formula:
140
140
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
141
141
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
142
142
143
- static mlir::Value getConvOutputDim (Location loc, Value inputDim,
144
- int64_t padBeforeAttr, int64_t padAfterAttr,
145
- Value kernelDim, int64_t strideAttr,
146
- int64_t dilationAttr, Type inputETy,
147
- OpBuilder &rewriter) {
143
+ static mlir::Value getConvOrPoolOutputDim (Location loc, Value inputDim,
144
+ int64_t padBeforeAttr,
145
+ int64_t padAfterAttr, Value kernelDim,
146
+ int64_t strideAttr,
147
+ int64_t dilationAttr,
148
+ OpBuilder &rewriter) {
148
149
ImplicitLocOpBuilder builder (loc, rewriter);
149
150
auto one = rewriter.create <arith::ConstantOp>(
150
151
loc, IntegerAttr::get (inputDim.getType (), 1 ));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
171
172
ArrayRef<int64_t > dilationAttr, ArrayRef<int64_t > inputSizeDims,
172
173
ArrayRef<int64_t > kernelSizeDims, OpBuilder &rewriter) {
173
174
ShapedType inputTy = cast<ShapedType>(input.getType ());
174
- Type inputETy = inputTy.getElementType ();
175
175
int64_t inputRank = inputTy.getRank ();
176
176
177
177
SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
190
190
rewriter.create <tensor::DimOp>(loc, weight, kernelDim);
191
191
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
192
192
dynDims[inputDim] =
193
- getConvOutputDim (loc, initDynDim, padTop, padBottom, kernelDynDim ,
194
- stride, dilation, inputETy , rewriter);
193
+ getConvOrPoolOutputDim (loc, initDynDim, padTop, padBottom,
194
+ kernelDynDim, stride, dilation , rewriter);
195
195
}
196
196
}
197
197
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
685
685
public:
686
686
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
687
687
688
+ // Compute the dynamic output sizes of the maxpool operation.
689
+ static SmallVector<Value>
690
+ computeDynamicOutputSizes (tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
691
+ TensorType resultTy = op.getType ();
692
+ Location loc = op.getLoc ();
693
+
694
+ TypedValue<TensorType> input = op.getInput ();
695
+ ArrayRef<int64_t > kernel = op.getKernel ();
696
+ ArrayRef<int64_t > pad = op.getPad ();
697
+ ArrayRef<int64_t > stride = op.getStride ();
698
+
699
+ SmallVector<Value> dynamicDims;
700
+
701
+ // Batch dimension
702
+ if (resultTy.isDynamicDim (0 ))
703
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
704
+
705
+ // Height/width dimensions
706
+ for (int64_t dim : {1 , 2 }) {
707
+ if (!resultTy.isDynamicDim (dim))
708
+ continue ;
709
+
710
+ // Index into the attribute arrays
711
+ int64_t index = dim - 1 ;
712
+
713
+ // Input height/width
714
+ Value ihw = rewriter.create <tensor::DimOp>(loc, input, dim);
715
+
716
+ // Kernel height/width
717
+ Value khw = rewriter.create <arith::ConstantIndexOp>(loc, kernel[index]);
718
+
719
+ // Output height/width
720
+ Value ohw = getConvOrPoolOutputDim (loc, ihw, pad[index * 2 ],
721
+ pad[index * 2 + 1 ], khw, stride[index],
722
+ /* dilationAttr=*/ 1 , rewriter);
723
+ dynamicDims.push_back (ohw);
724
+ }
725
+
726
+ // Channel dimension
727
+ if (resultTy.isDynamicDim (3 ))
728
+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 3 ));
729
+
730
+ return dynamicDims;
731
+ }
732
+
688
733
LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
689
734
PatternRewriter &rewriter) const final {
690
735
Location loc = op.getLoc ();
691
- Value input = op.getInput ();
692
- ShapedType inputTy = cast<ShapedType>( input.getType () );
736
+ TypedValue<TensorType> input = op.getInput ();
737
+ ShapedType inputTy = input.getType ();
693
738
694
- ShapedType resultTy = cast<ShapedType>( op.getType () );
739
+ ShapedType resultTy = op.getType ();
695
740
Type resultETy = inputTy.getElementType ();
696
741
697
- auto dynamicDimsOr =
698
- checkHasDynamicBatchDims (rewriter, op, {input, op.getOutput ()});
699
- if (!dynamicDimsOr.has_value ())
700
- return failure ();
701
- SmallVector<Value> dynamicDims = *dynamicDimsOr;
742
+ SmallVector<Value> dynamicDims = computeDynamicOutputSizes (op, rewriter);
702
743
703
744
// Determine what the initial value needs to be for the max pool op.
704
745
TypedAttr initialAttr;
@@ -721,6 +762,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
721
762
pad.resize (2 , 0 );
722
763
llvm::append_range (pad, op.getPad ());
723
764
pad.resize (pad.size () + 2 , 0 );
765
+
724
766
Value paddedInput = applyPad (loc, input, pad, initialAttr, rewriter);
725
767
726
768
Value initialValue = rewriter.create <arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
736
778
loc, resultTy.getShape (), resultTy.getElementType (), dynamicDims);
737
779
738
780
Value filledEmptyTensor =
739
- rewriter
740
- .create <linalg::FillOp>(loc, ValueRange{initialValue},
741
- ValueRange{emptyTensor})
781
+ rewriter.create <linalg::FillOp>(loc, initialValue, emptyTensor)
742
782
.result ();
743
783
744
784
Value fakeWindowDims =
0 commit comments