
[TOSA] Fix avgpool2d accum in wider type #80849


Merged

RoboTux merged 1 commit into llvm:main from fix_avgpool2d_fp32_fp16 on Feb 7, 2024

Conversation

RoboTux
Contributor

@RoboTux commented Feb 6, 2024

Truncate result of avgpool when accumulation is done in a wider type
than the result element type, such as when doing a f16 avgpool2d with a
f32 accumulator type.
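
As a minimal sketch of the affected case (the function name, shapes, and SSA value names below are hypothetical and chosen only for illustration; the PR's own test below exercises a 4x4 kernel on a 1x6x34x62xf16 input):

  func.func @avg_pool_f16_acc_f32(%in: tensor<1x8x8x4xf16>) -> tensor<1x9x9x4xf16> {
    // f16 pooling with an f32 accumulator: the sum and the division by the
    // element count happen in f32 inside the lowered linalg.generic.
    %0 = tosa.avg_pool2d %in {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>} : (tensor<1x8x8x4xf16>) -> tensor<1x9x9x4xf16>
    return %0 : tensor<1x9x9x4xf16>
  }

Previously the f32 result of the division was yielded directly into the f16 output; with this change the lowering truncates it back to the result element type first, schematically:

    %div = arith.divf %sum, %cnt : f32
    %res = arith.truncf %div : f32 to f16
    linalg.yield %res : f16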

@llvmbot
Member

llvmbot commented Feb 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Thomas Preud'homme (RoboTux)

Changes

Truncate result of avgpool when accumulation is done in a wider type
than the result element type, such as when doing a f16 avgpool2d with a
f32 accumulator type.


Full diff: https://github.com/llvm/llvm-project/pull/80849.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+91)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff..607a603cca810f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -890,6 +890,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
             poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                           ->getResult(0);
+            if (accETy.getIntOrFloatBitWidth() >
+                resultETy.getIntOrFloatBitWidth())
+              poolVal =
+                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
           } else {
 
             // If we have quantization information we need to apply an offset
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa..51ebcad0797807 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -306,6 +306,97 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
 
 // -----
 
+// CHECK-LABEL: @avg_pool_f16_f32acc
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x62xf16>) {
+  // Apply padding to the input:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK:   tensor.yield %[[F0]] : f16
+
+  // Fill the pooling target:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+
+  // Compute the sum padding:
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+  // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf16>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+
+  // Compute dimension based constants:
+  // CHECK: %[[I1:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+  // CHECK: %[[I2:.+]] = arith.constant 2 : index
+  // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+  // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+  // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+  // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+
+  // Divide the sum pooling by the number of summed values.
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf16>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf16>)
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f16)
+  // CHECK:   %[[ZERO:.+]] = arith.constant 0
+
+  // Compute how much of the height does not include padding:
+  // CHECK:   %[[STRIDE:.+]] = arith.constant 1
+  // CHECK:   %[[KSIZE:.+]] = arith.constant 4
+  // CHECK:   %[[START:.+]] = linalg.index 1
+  // CHECK:   %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+  // CHECK:   %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK:   %[[PAD_START:.+]] = arith.constant 1
+  // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK:   %[[PAD_END:.+]] = arith.constant 1
+  // CHECK:   %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK:   %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Compute how much of the width does not include padding:
+  // CHECK:   %[[STRIDE:.+]] = arith.constant 1
+  // CHECK:   %[[KSIZE:.+]] = arith.constant 4
+  // CHECK:   %[[START:.+]] = linalg.index 2
+  // CHECK:   %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+  // CHECK:   %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK:   %[[PAD_START:.+]] = arith.constant 1
+  // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK:   %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK:   %[[PAD_END:.+]] = arith.constant 1
+  // CHECK:   %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK:   %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK:   %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Divide the summed value by the number of values summed.
+  // CHECK:   %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK:   %[[FLT:.+]] = arith.sitofp %[[CAST]]
+  // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+  // CHECK:   %[[TRUNC:.+]] = arith.truncf %[[DIV]]
+  // CHECK:   linalg.yield %[[TRUNC]]
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
+  return %0 : tensor<1x5x33x62xf16>
+}
+
+// -----
+
 // CHECK-LABEL: @avg_pool_i8
 func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[GENERIC:.+]] = linalg.generic

@RoboTux merged commit 88c830a into llvm:main on Feb 7, 2024
@RoboTux deleted the fix_avgpool2d_fp32_fp16 branch on April 18, 2024 at 15:58