Skip to content

Commit 86858c6

Browse files
committed
[mlir][tosa] Add dilation to tosa.transpose_conv2d lowering
Dilation only requires increasing the padding on each side of the input, and including dilation in the convolution. This implementation still lacks support for strided convolutions. Reviewed By: NatashaKnk. Differential Revision: https://reviews.llvm.org/D107680
1 parent a1f4656 commit 86858c6

File tree

2 files changed

+50
-47
lines changed

2 files changed

+50
-47
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,56 +1029,49 @@ class TransposeConvConverter
10291029
getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
10301030
getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
10311031

1032-
// We have not solved for stride / dilation yet. Dilation should be
1033-
// straightforward but stride is more complicated. Linalg work is likely
1034-
// required for efficient implementation.
1035-
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
1036-
return failure();
1037-
if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
1038-
return failure();
1039-
1040-
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
1041-
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
1042-
return failure();
1032+
// If striding is all 1 we can modify padding and reverse the kernel along
1033+
// the x/y direction to make it a regular convolution. This is much simpler
1034+
// than handling striding.
1035+
if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
1036+
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
1037+
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
1038+
return failure();
1039+
1040+
int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
1041+
int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
1042+
int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
1043+
int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
1044+
1045+
llvm::SmallVector<int64_t> convPad(4, 0);
1046+
convPad[0] = kernelHeight - 1 - pad[0];
1047+
convPad[2] = kernelWidth - 1 - pad[1];
1048+
convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
1049+
convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
1050+
1051+
auto reverse1 = rewriter.create<tosa::ReverseOp>(
1052+
loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
1053+
auto reverse2 = rewriter.create<tosa::ReverseOp>(
1054+
loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
1055+
1056+
Value conv2d;
1057+
if (op.quantization_info().hasValue()) {
1058+
conv2d = rewriter.create<tosa::Conv2DOp>(
1059+
loc, resultTy, input, reverse2, bias,
1060+
rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
1061+
rewriter.getI64ArrayAttr(dilation),
1062+
op.quantization_info().getValue());
1063+
} else {
1064+
conv2d = rewriter.create<tosa::Conv2DOp>(
1065+
loc, resultTy, input, reverse2, bias,
1066+
rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
1067+
rewriter.getI64ArrayAttr(dilation));
1068+
}
10431069

1044-
int64_t inputHeight = inputTy.getDimSize(1);
1045-
int64_t inputWidth = inputTy.getDimSize(2);
1046-
int64_t kernelHeight = weightTy.getDimSize(1);
1047-
int64_t kernelWidth = weightTy.getDimSize(2);
1048-
int64_t outputHeight = resultTy.getDimSize(1);
1049-
int64_t outputWidth = resultTy.getDimSize(2);
1050-
1051-
int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
1052-
int64_t requiredInputWidth = outputWidth + kernelWidth - 1;
1053-
1054-
llvm::SmallVector<int64_t> newPad(4, 0);
1055-
newPad[0] = kernelHeight - 1 - pad[0];
1056-
newPad[2] = kernelWidth - 1 - pad[1];
1057-
1058-
newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
1059-
newPad[3] = requiredInputWidth - newPad[2] - inputWidth;
1060-
1061-
auto reverse1 = rewriter.create<tosa::ReverseOp>(
1062-
loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
1063-
auto reverse2 = rewriter.create<tosa::ReverseOp>(
1064-
loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
1065-
1066-
Value conv2d;
1067-
if (op.quantization_info().hasValue()) {
1068-
conv2d = rewriter.create<tosa::Conv2DOp>(
1069-
loc, resultTy, input, reverse2, bias,
1070-
rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
1071-
rewriter.getI64ArrayAttr(dilation),
1072-
op.quantization_info().getValue());
1073-
} else {
1074-
conv2d = rewriter.create<tosa::Conv2DOp>(
1075-
loc, resultTy, input, reverse2, bias,
1076-
rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
1077-
rewriter.getI64ArrayAttr(dilation));
1070+
rewriter.replaceOp(op, conv2d);
1071+
return success();
10781072
}
10791073

1080-
rewriter.replaceOp(op, conv2d);
1081-
return success();
1074+
return failure();
10821075
}
10831076
};
10841077

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,6 +1274,16 @@ func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>,
12741274
return
12751275
}
12761276

1277+
// -----
1278+
1279+
// CHECK-LABEL: @transpose_conv_dilated
1280+
func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
1281+
// CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
1282+
// CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<4x3x3x2xf32>)
1283+
%0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
1284+
return
1285+
}
1286+
12771287

12781288
// -----
12791289

0 commit comments

Comments
 (0)