Skip to content

Commit 659cca7

Browse files
committed
Revert "[mlir][TOSA] Fix linalg lowering of depthwise conv2d (llvm#130282)"
This reverts commit e22579a.
1 parent d31a7dd commit 659cca7

File tree

2 files changed

+19
-43
lines changed

2 files changed

+19
-43
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -477,21 +477,27 @@ class DepthwiseConvConverter
477477
return rewriter.notifyMatchFailure(
478478
op, "weight zero point must be zero for non-int8 integer types");
479479

480+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
480481
auto weightShape = weightTy.getShape();
481482
auto resultShape = resultTy.getShape();
482483

483484
// Apply padding as necessary.
484-
int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
485-
.getSExtValue();
486-
int64_t intMax = APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
487-
.getSExtValue();
485+
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
486+
if (hasZp) {
487+
int64_t intMin =
488+
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
489+
.getSExtValue();
490+
int64_t intMax =
491+
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
492+
.getSExtValue();
488493

489-
if (inputZpVal < intMin || inputZpVal > intMax)
490-
return rewriter.notifyMatchFailure(
491-
op, "tosa.depthwise_conv op quantization has zp outside of input "
492-
"range");
494+
if (inputZpVal < intMin || inputZpVal > intMax)
495+
return rewriter.notifyMatchFailure(
496+
op, "tosa.depthwise_conv op quantization has zp outside of input "
497+
"range");
493498

494-
TypedAttr zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
499+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
500+
}
495501

496502
llvm::SmallVector<int64_t> pad;
497503
pad.resize(2, 0);
@@ -530,7 +536,7 @@ class DepthwiseConvConverter
530536
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
531537
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
532538

533-
if (inputZpVal == 0 && weightZpVal == 0) {
539+
if (!hasZp) {
534540
Value conv = rewriter
535541
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
536542
loc, linalgConvTy, ValueRange{input, weight},
@@ -550,13 +556,8 @@ class DepthwiseConvConverter
550556
getNParallelLoopsAttrs(resultRank),
551557
[&](OpBuilder &nestedBuilder, Location nestedLoc,
552558
ValueRange args) {
553-
Value added;
554-
if (llvm::isa<FloatType>(inputETy))
555-
added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
556-
args[1]);
557-
else
558-
added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
559-
args[1]);
559+
Value added = nestedBuilder.create<arith::AddFOp>(
560+
loc, args[0], args[1]);
560561
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
561562
})
562563
.getResult(0);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -798,10 +798,9 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
798798
// CHECK: arith.subi
799799
// CHECK: arith.muli
800800
// CHECK: arith.divui
801-
// CHECK: [[CST0:%.+]] = arith.constant 0
802801
// CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 3, 0] high[0, 2, 4, 0] {
803802
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: index, %[[ARG4:[0-9a-zA-Z_]+]]: index, %[[ARG5:[0-9a-zA-Z_]+]]: index, %[[ARG6:[0-9a-zA-Z_]+]]: index):
804-
// CHECK: tensor.yield [[CST0]] : f32
803+
// CHECK: tensor.yield %cst : f32
805804
// CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
806805
// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%{{.*}} : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
807806
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]]
@@ -813,30 +812,6 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
813812

814813
// -----
815814

816-
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
817-
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
818-
819-
// CHECK-LABEL: @depthwise_int_conv_zero_zp
820-
func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor<3x1x3x11xi8>, %arg2 : tensor<33xi32>) -> () {
821-
// CHECK: [[INIT:%.+]] = tensor.empty()
822-
// CHECK: [[CST0:%.+]] = arith.constant 0
823-
// CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
824-
// CHECK: [[OUT:%.+]] = tensor.empty()
825-
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xi8>, tensor<3x1x3x11xi8>) outs([[FILL]] : tensor<1x5x5x3x11xi32>)
826-
// CHECK: [[COLLAPSED:%.+]] = tensor.collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
827-
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xi32>, tensor<1x5x5x33xi32>) outs([[OUT]] : tensor<1x5x5x33xi32>) {
828-
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: i32, %[[ARG4:[0-9a-zA-Z_]+]]: i32, %[[ARG5:[0-9a-zA-Z_]+]]: i32):
829-
// CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
830-
// CHECK: linalg.yield [[ADD]] : i32
831-
// CHECK: } -> tensor<1x5x5x33xi32>
832-
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
833-
%weight_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
834-
%2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xi8>, tensor<3x1x3x11xi8>, tensor<33xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x5x33xi32>
835-
return
836-
}
837-
838-
// -----
839-
840815
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
841816
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
842817

0 commit comments

Comments
 (0)