[mlir][tosa] Improve lowering of tosa.conv2d #74143
The existing lowering of tosa.conv2d emits a separate linalg.generic operator to add the bias after computing the convolution. This change eliminates that additional step by using the bias value as the initial (outs) operand of the generated linalg.conv_2d_* operation. Rather than:

%init = tensor.empty()
%conv = linalg.conv_2d ins(%A, %B) outs(%init)
%init2 = tensor.empty()
%result = linalg.generic ins(%conv, %bias) outs(%init2) {
  // perform add operation
}

The lowering now produces:

%init = tensor.empty()
%bias_expanded = linalg.broadcast ins(%bias) outs(%init)
%conv = linalg.conv_2d ins(%A, %B) outs(%bias_expanded)
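For concreteness, here is a minimal standalone sketch of the new lowering shape for a float conv2d. The function name, tensor shapes, and attribute values are illustrative assumptions, and linalg.broadcast stands in for the broadcasting linalg.generic the pass actually builds:

// Hypothetical shapes: NHWC input 1x49x42x27, FHWC filter 28x3x3x27,
// stride 1 and dilation 1 give an NHWC result of 1x47x40x28.
func.func @conv2d_with_bias(%input: tensor<1x49x42x27xf32>,
                            %filter: tensor<28x3x3x27xf32>,
                            %bias: tensor<28xf32>) -> tensor<1x47x40x28xf32> {
  %init = tensor.empty() : tensor<1x47x40x28xf32>
  // Broadcast the per-channel bias across the N, H, and W dimensions.
  %acc = linalg.broadcast ins(%bias : tensor<28xf32>)
             outs(%init : tensor<1x47x40x28xf32>) dimensions = [0, 1, 2]
  // The convolution accumulates directly on top of the broadcast bias,
  // so no trailing linalg.generic add is needed.
  %conv = linalg.conv_2d_nhwc_fhwc
              {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
              ins(%input, %filter : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>)
              outs(%acc : tensor<1x47x40x28xf32>) -> tensor<1x47x40x28xf32>
  return %conv : tensor<1x47x40x28xf32>
}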
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Spenser Bauman (sabauma)
This is the same strategy as #73049 applied to convolutions. Full diff: https://github.com/llvm/llvm-project/pull/74143.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 0accd9d1986a1ed..b3fbc7dd0b22c19 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -344,15 +344,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
weightPermValue);
}
- auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultETy, filteredDims);
- Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
- .result();
-
// Extract the attributes for convolution.
ArrayRef<int64_t> stride = strideTosaAttr;
ArrayRef<int64_t> dilation = dilationTosaAttr;
@@ -361,18 +352,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto strideAttr = rewriter.getI64TensorAttr(stride);
auto dilationAttr = rewriter.getI64TensorAttr(dilation);
- // Create maps for the bias broadcasting
- SmallVector<AffineMap, 4> indexingMaps;
- indexingMaps.push_back(AffineMap::get(
- /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
- {rewriter.getAffineDimExpr(resultTy.getRank() - 1)},
- rewriter.getContext()));
- indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
- indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultETy, filteredDims);
+ Value broadcastBias =
+ linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+
if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
@@ -380,38 +365,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+
Value conv =
rewriter
.create<LinalgConvQOp>(
loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
- Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, conv,
- biasEmptyTensor, indexingMaps);
- rewriter.replaceOp(op, result);
+
+ rewriter.replaceOp(op, conv);
return success();
}
Value conv = rewriter
.create<LinalgConvOp>(
loc, resultTy, ValueRange{input, weight},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
- Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor,
- indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added = nestedBuilder.create<arith::AddFOp>(
- loc, args[0], args[1]);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
- })
- .getResult(0);
-
- rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, conv);
return success();
}
};
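The helper linalgBroadcastAndMaybeExtSI introduced above serves both the float and the quantized paths: judging by its name and by the i8 test updates below, when the bias element type is narrower than the accumulator type, the broadcasting linalg.generic also sign-extends each element. A hand-written sketch of the IR shape the i8 test checks for (the function name and shapes are illustrative assumptions):

#map_bias = affine_map<(d0, d1, d2, d3) -> (d3)>
#map_out = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Broadcast an i8 per-channel bias into an i32 accumulator, sign-extending
// each element as it is copied.
func.func @broadcast_bias_extsi(%bias: tensor<28xi8>) -> tensor<1x45x40x28xi32> {
  %init = tensor.empty() : tensor<1x45x40x28xi32>
  %acc = linalg.generic
             {indexing_maps = [#map_bias, #map_out],
              iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
             ins(%bias : tensor<28xi8>) outs(%init : tensor<1x45x40x28xi32>) {
  ^bb0(%in: i8, %out: i32):
    %ext = arith.extsi %in : i8 to i32
    linalg.yield %ext : i32
  } -> tensor<1x45x40x28xi32>
  return %acc : tensor<1x45x40x28xi32>
}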
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 230001f7633b570..aa010e759a0f201 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -378,16 +378,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
// HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
- // CHECK: %[[M_IN:.+]] = tensor.empty()
- // CHECK: %[[CST:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill
- // CHECK: %[[B_IN:.+]] = tensor.empty()
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
- // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
- // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
// CHECK: arith.extsi
- // CHECK: arith.addi
// CHECK: linalg.yield
+ // CHECK: } -> tensor<1x45x40x28xi32>
+ // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+ // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+
%0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
return
}
@@ -401,15 +399,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
// HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
- // CHECK: %[[M_IN:.+]] = tensor.empty()
- // CHECK: %[[CST:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill
- // CHECK: %[[B_IN:.+]] = tensor.empty()
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
- // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
- // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
- // CHECK: arith.addf
+
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
// CHECK: linalg.yield
+ // CHECK: } -> tensor<1x45x40x28xf32>
+ // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+
+ // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
%0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -421,16 +418,14 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
// CHECK-LABEL: @conv2d_dyn
func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
- // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
- // CHECK: %[[CST:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill
- // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
- // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
- // CHECK: %[[ADD:.+]] = arith.addf
- // CHECK: linalg.yield %[[ADD]] : f32
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x49x42x27xf32>
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x45x40x28xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<?x45x40x28xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32):
+ // CHECK: linalg.yield %[[IN]] : f32
+ // CHECK: } -> tensor<?x45x40x28xf32>
+ // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
%0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
return
}
@@ -481,14 +476,12 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
// CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
// Running convolution
- // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
- // CHECK: %[[CST:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill
- // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
- // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
- // CHECK: %[[ADD:.+]] = arith.addf
- // CHECK: linalg.yield %[[ADD]] : f32
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) : tensor<1x?x?x28xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x?x?x28xf32>) {
+ // CHECK: linalg.yield
+ // CHECK: } -> tensor<1x?x?x28xf32>
+ // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
+
%0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
return
}
@@ -678,52 +671,52 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
// -----
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
// CHECK-LABEL: @conv3d_f32
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
- // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
- // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
- // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
- // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
- // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>)
- // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
- // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf
+ // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic
+ // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%1 : tensor<1x47x45x43x28xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+ // CHECK: linalg.yield %[[IN]] : f32
+ // CHECK: } -> tensor<1x47x45x43x28xf32>
+ // CHECK: linalg.conv_3d_ndhwc_dhwcf
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
- // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
- // CHECK: %[[GENERIC:.+]] = linalg.generic
- // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
- // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, tensor<1x47x45x43x28xf32>)
- // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
- // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
- // CHECK: %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
- // CHECK: linalg.yield %[[ADD]]
+ // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
%0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32>
return
}
// -----
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
// CHECK-LABEL: @conv3d_i8
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
- // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
- // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
- // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
- // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
- // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
- // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
- // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
- // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
- // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+ // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic
+ // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
+ // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+ // CHECK: linalg.yield %[[IN]] : i32
+ // CHECK: } -> tensor<1x47x45x43x28xi32>
+ // CHECK: %[[IZP:.+]] = arith.constant -128 : i32
+ // CHECK: %[[FZP:.+]] = arith.constant 42 : i32
+ // CHECK: linalg.conv_3d_ndhwc_dhwcf_q
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
- // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
- // CHECK: %[[GENERIC:.+]] = linalg.generic
- // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
- // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
- // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
- // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
- // CHECK: %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
- // CHECK: linalg.yield %[[ADD]]
+ // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+
%0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
return
}
Looks like a nice simplification!