@@ -1029,56 +1029,49 @@ class TransposeConvConverter
1029
1029
getValuesFromIntArrayAttribute (op.stride ().cast <ArrayAttr>(), stride);
1030
1030
getValuesFromIntArrayAttribute (op.dilation ().cast <ArrayAttr>(), dilation);
1031
1031
1032
- // We have not solved for stride / dilation yet. Dilation should be
1033
- // straight forward but stride is more complicated. Linalg work is likely
1034
- // required for efficient implementation.
1035
- if (llvm::any_of (stride, [](int64_t v) { return v != 1 ; }))
1036
- return failure ();
1037
- if (llvm::any_of (dilation, [](int64_t v) { return v != 1 ; }))
1038
- return failure ();
1039
-
1040
- if (!inputTy.hasStaticShape () || !weightTy.hasStaticShape () ||
1041
- !biasTy.hasStaticShape () || !resultTy.hasStaticShape ())
1042
- return failure ();
1032
+ // If striding is all 1 we can modify padding and reverse the kernel along
1033
+ // the x/y direction to make it a regular convolution. This is much simpler
1034
+ // then handling striding....
1035
+ if (llvm::all_of (stride, [](int64_t v) { return v == 1 ; })) {
1036
+ if (!inputTy.hasStaticShape () || !weightTy.hasStaticShape () ||
1037
+ !biasTy.hasStaticShape () || !resultTy.hasStaticShape ())
1038
+ return failure ();
1039
+
1040
+ int64_t kernelHeight = (weightTy.getDimSize (1 ) - 1 ) * dilation[0 ] + 1 ;
1041
+ int64_t kernelWidth = (weightTy.getDimSize (2 ) - 1 ) * dilation[1 ] + 1 ;
1042
+ int64_t requiredInputHeight = resultTy.getDimSize (1 ) + kernelHeight - 1 ;
1043
+ int64_t requiredInputWidth = resultTy.getDimSize (2 ) + kernelWidth - 1 ;
1044
+
1045
+ llvm::SmallVector<int64_t > convPad (4 , 0 );
1046
+ convPad[0 ] = kernelHeight - 1 - pad[0 ];
1047
+ convPad[2 ] = kernelWidth - 1 - pad[1 ];
1048
+ convPad[1 ] = requiredInputHeight - convPad[0 ] - inputTy.getDimSize (1 );
1049
+ convPad[3 ] = requiredInputWidth - convPad[2 ] - inputTy.getDimSize (2 );
1050
+
1051
+ auto reverse1 = rewriter.create <tosa::ReverseOp>(
1052
+ loc, weightTy, weight, rewriter.getI64IntegerAttr (1 ));
1053
+ auto reverse2 = rewriter.create <tosa::ReverseOp>(
1054
+ loc, weightTy, reverse1, rewriter.getI64IntegerAttr (2 ));
1055
+
1056
+ Value conv2d;
1057
+ if (op.quantization_info ().hasValue ()) {
1058
+ conv2d = rewriter.create <tosa::Conv2DOp>(
1059
+ loc, resultTy, input, reverse2, bias,
1060
+ rewriter.getI64ArrayAttr (convPad), rewriter.getI64ArrayAttr (stride),
1061
+ rewriter.getI64ArrayAttr (dilation),
1062
+ op.quantization_info ().getValue ());
1063
+ } else {
1064
+ conv2d = rewriter.create <tosa::Conv2DOp>(
1065
+ loc, resultTy, input, reverse2, bias,
1066
+ rewriter.getI64ArrayAttr (convPad), rewriter.getI64ArrayAttr (stride),
1067
+ rewriter.getI64ArrayAttr (dilation));
1068
+ }
1043
1069
1044
- int64_t inputHeight = inputTy.getDimSize (1 );
1045
- int64_t inputWidth = inputTy.getDimSize (2 );
1046
- int64_t kernelHeight = weightTy.getDimSize (1 );
1047
- int64_t kernelWidth = weightTy.getDimSize (2 );
1048
- int64_t outputHeight = resultTy.getDimSize (1 );
1049
- int64_t outputWidth = resultTy.getDimSize (2 );
1050
-
1051
- int64_t requiredInputHeight = outputHeight + kernelHeight - 1 ;
1052
- int64_t requiredInputWidth = outputWidth + kernelWidth - 1 ;
1053
-
1054
- llvm::SmallVector<int64_t > newPad (4 , 0 );
1055
- newPad[0 ] = kernelHeight - 1 - pad[0 ];
1056
- newPad[2 ] = kernelWidth - 1 - pad[1 ];
1057
-
1058
- newPad[1 ] = requiredInputHeight - newPad[0 ] - inputHeight;
1059
- newPad[3 ] = requiredInputWidth - newPad[2 ] - inputWidth;
1060
-
1061
- auto reverse1 = rewriter.create <tosa::ReverseOp>(
1062
- loc, weightTy, weight, rewriter.getI64IntegerAttr (1 ));
1063
- auto reverse2 = rewriter.create <tosa::ReverseOp>(
1064
- loc, weightTy, reverse1, rewriter.getI64IntegerAttr (2 ));
1065
-
1066
- Value conv2d;
1067
- if (op.quantization_info ().hasValue ()) {
1068
- conv2d = rewriter.create <tosa::Conv2DOp>(
1069
- loc, resultTy, input, reverse2, bias,
1070
- rewriter.getI64ArrayAttr (newPad), rewriter.getI64ArrayAttr (stride),
1071
- rewriter.getI64ArrayAttr (dilation),
1072
- op.quantization_info ().getValue ());
1073
- } else {
1074
- conv2d = rewriter.create <tosa::Conv2DOp>(
1075
- loc, resultTy, input, reverse2, bias,
1076
- rewriter.getI64ArrayAttr (newPad), rewriter.getI64ArrayAttr (stride),
1077
- rewriter.getI64ArrayAttr (dilation));
1070
+ rewriter.replaceOp (op, conv2d);
1071
+ return success ();
1078
1072
}
1079
1073
1080
- rewriter.replaceOp (op, conv2d);
1081
- return success ();
1074
+ return failure ();
1082
1075
}
1083
1076
};
1084
1077
0 commit comments