[mlir][tosa] Enhance folder for Tosa binary operators #128059

Closed
Tai78641 wants to merge 1 commit from the pr_enhance_folder branch

Conversation

Tai78641
Contributor

This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops (a small example follows the list):

  • mul
  • add
  • sub
  • greater
  • greater_equal
  • equal
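
For illustration, an add of two non-splat constants now folds to a single constant. This is a before/after sketch with values borrowed from the updated reduce_sum test in this patch:

%a = "tosa.const"() <{value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%b = "tosa.const"() <{value = dense<[[1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%sum = tosa.add %a, %b : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// after folding, %sum is materialized as the constant
// "tosa.const"() <{value = dense<[[2, 4, 6], [8, 10, 13]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>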

@llvmbot
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops:

  • mul
  • add
  • sub
  • greater
  • greater_equal
  • equal

Patch is 24.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128059.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+148-32)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-5)
  • (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+215-3)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..2c6c6e2ed284c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
-template <typename IntFolder, typename FloatFolder>
+template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
 DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                RankedTensorType returnTy) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
-    if (lETy != rETy)
-      return {};
+  if (!rhs || !lhs)
+    return {};
+
+  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  if (lETy != rETy)
+    return {};
+
+  if (!lETy.isIntOrFloat())
+    return {};
 
+  if (rhs.isSplat() && lhs.isSplat()) {
     if (llvm::isa<IntegerType>(lETy)) {
       APInt l = lhs.getSplatValue<APInt>();
       APInt r = rhs.getSplatValue<APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
     }
   }
 
+  if (llvm::isa<IntegerType>(lETy)) {
+    auto lvalues = lhs.getValues<APInt>();
+    auto rvalues = rhs.getValues<APInt>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
+    }
+    SmallVector<APInt> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      auto result = IntFolder()(l, r);
+      results.push_back(result);
+    }
+    return DenseElementsAttr::get(returnTy, results);
+  }
+
+  if (llvm::isa<FloatType>(lETy)) {
+    auto lvalues = lhs.getValues<APFloat>();
+    auto rvalues = rhs.getValues<APFloat>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
+    }
+    // FloatFolder() may return either APFloat or APInt (comparison functions)
+    SmallVector<FloatResultAPType> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      auto result = FloatFolder()(l, r);
+      results.push_back(result);
+    }
+    return DenseElementsAttr::get(returnTy, results);
+  }
+
   return {};
 }
 
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
+                                         DenseElementsAttr rhs,
+                                         RankedTensorType returnTy) {
+  // comparison FloatFolder() functions return APInt values
+  return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
+}
+
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
+                                         DenseElementsAttr rhs,
+                                         RankedTensorType returnTy) {
+  // arithmetic FloatFolder() functions return APFloat values
+  return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
+}
+
 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   if (llvm::isa<FloatType>(elemType))
     return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
-                                                            resultTy);
+  return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
+      lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 }
 
 namespace {
+
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+                            unsigned bitwidth) {
+  APInt result = lhs.sext(64) * rhs.sext(64);
+
+  if (shift > 0) {
+    auto round = APInt(64, 1) << (shift - 1);
+    result += round;
+    result.ashrInPlace(shift);
+    // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+    if (!(result.getSExtValue() >= INT32_MIN &&
+          result.getSExtValue() <= INT32_MAX)) {
+      // REQUIRE failed
+      return std::nullopt;
+    }
+  }
+
+  return result.trunc(bitwidth);
+}
+
 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                   RankedTensorType ty, int32_t shift) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    if (llvm::isa<IntegerType>(ty.getElementType())) {
-      APInt l = lhs.getSplatValue<APInt>();
-      APInt r = rhs.getSplatValue<APInt>();
+  if (!lhs || !rhs)
+    return {};
+
+  // REQUIRE(0 <= shift && shift <= 63);
+  if (!(0 <= shift && shift <= 63))
+    return {};
 
-      if (shift == 0) {
-        return DenseElementsAttr::get(ty, l * r);
+  auto elementType = ty.getElementType();
+  if (!elementType.isIntOrFloat())
+    return {};
+
+  unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+  // REQUIRE(in_t == int32_t || shift == 0);
+  if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
+    return {};
+
+  if (rhs.isSplat() && lhs.isSplat()) {
+    if (llvm::isa<IntegerType>(elementType)) {
+      auto l = lhs.getSplatValue<APInt>();
+      auto r = rhs.getSplatValue<APInt>();
+
+      if (auto result = mulInt(l, r, shift, bitwidth)) {
+        return DenseElementsAttr::get(ty, result.value());
       }
+      // mulInt failed
+      return {};
+    }
 
-      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
-      l = l.sext(bitwidth * 2);
-      r = r.sext(bitwidth * 2);
+    if (llvm::isa<FloatType>(elementType)) {
+      auto l = lhs.getSplatValue<APFloat>();
+      auto r = rhs.getSplatValue<APFloat>();
       auto result = l * r;
-      result.lshrInPlace(shift);
-      result = result.trunc(bitwidth);
       return DenseElementsAttr::get(ty, result);
     }
+  }
+
+  if (llvm::isa<IntegerType>(elementType)) {
+    auto lvalues = lhs.getValues<APInt>();
+    auto rvalues = rhs.getValues<APInt>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
+    }
+    SmallVector<APInt> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      if (auto result = mulInt(l, r, shift, bitwidth)) {
+        results.push_back(result.value());
+        continue;
+      }
+      // mulInt failed
+      return {};
+    }
+    return DenseElementsAttr::get(ty, results);
+  }
 
-    if (llvm::isa<FloatType>(ty.getElementType())) {
-      APFloat l = lhs.getSplatValue<APFloat>();
-      APFloat r = rhs.getSplatValue<APFloat>();
-      APFloat result = l * r;
-      return DenseElementsAttr::get(ty, result);
+  if (llvm::isa<FloatType>(elementType)) {
+    auto lvalues = lhs.getValues<APFloat>();
+    auto rvalues = rhs.getValues<APFloat>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
     }
+    SmallVector<APFloat> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      auto result = l * r;
+      results.push_back(result);
+    }
+    return DenseElementsAttr::get(ty, results);
   }
 
   return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
+      lhsAttr, rhsAttr, resultTy);
 }
 
 namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+  return comparisonBinaryFolder<APIntFoldGreater,
+                                ComparisonFold<std::greater<APFloat>>>(
       lhsAttr, rhsAttr, resultTy);
 }
 
@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreaterEqual,
-                      ComparisonFold<std::greater_equal<APFloat>>>(
+  return comparisonBinaryFolder<APIntFoldGreaterEqual,
+                                ComparisonFold<std::greater_equal<APFloat>>>(
       lhsAttr, rhsAttr, resultTy);
 }
 
@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
-                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+                                ComparisonFold<std::equal_to<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..5aab368fa044d 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1092,11 +1092,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
 
 func.func @reduce_sum_constant() -> tensor<1x3xi32> {
   // CHECK-LABEL:     func.func @reduce_sum_constant() -> tensor<1x3xi32> {
-  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
-  // CHECK:           return %[[VAL_3]] : tensor<1x3xi32>
+  // CHECK:           %[[K:.*]] = "tosa.const"() <{value = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+  // CHECK:           return %[[K]] : tensor<1x3xi32>
   %arg0 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
   %arg1 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
   %arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..fee1ce7793b12 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize --test-constant-fold %s | FileCheck %s
+
+// -----
 
 // CHECK-LABEL: func @test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +9,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @test_const_i64
 func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   // CHECK: tosa.const
@@ -14,10 +18,218 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   return %0 : tensor<4xi64>
 }
 
+// -----
+
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<*xi1> {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
   %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
-  return
+  return %0 : tensor<*xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32
+func.func @test_mul_i32() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[9, 36, 36, 81]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %lhs = "tosa.const"() {value = dense<[1, 2, -2, -3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+  %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32_shift
+func.func @test_mul_i32_shift() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2550, 8100, 2, 2025]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %lhs = "tosa.const"() {value = dense<[135, 240, -4, -120]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+  %shift = "tosa.const"() { value = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_f32
+func.func @test_mul_f32() -> tensor<4xf32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2.304000e+01, 58.9824028, 1.6384002, 14.7456007]> : tensor<4xf32>}>
+  // CHECK: return %[[VAL]]
+  %lhs = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %rhs = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+  %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+  %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+  %result = tosa.mul %x, %y, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+  return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_f32
+func.func @test_add_f32() -> tensor<4xf32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[7.500000e+00, 9.300000e+00, 3.69999981, 2.100000e+00]> : tensor<4xf32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.add %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %y = tosa.add %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %result = tosa.add %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_i32
+func.func @test_add_i32() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[75, 93, 37, 21]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.add %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %y = tosa.add %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %result = tosa.add %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_f32
+func.func @test_sub_f32() -> tensor<4xf32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-1.500000e+00, 0.300000191, -5.300000e+00, -6.900000e+00]> : tensor<4xf32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.sub %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %y = tosa.sub %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %result = tosa.sub %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_i32
+func.func @test_sub_i32() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-15, 3, -53, -69]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.sub %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %y = tosa.sub %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %result = tosa.sub %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_f32
+func.func @test_greater_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+  %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+  %cst2 = "tosa.const"() {value = dense<[1.7, 2.3, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.greater %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %y = tosa.greater %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %z = tosa.greater %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_i32
+func.func @test_greater_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+  %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %cst2 = "tosa.const"() {value = dense<[17, 23, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat = "tosa.const"() {value = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.greater %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %y = tosa.greater %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %z = tosa.greater %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_f32
+func.func @test_greater_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.c...
[truncated]
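
The mulInt helper in the diff above implements the spec's rounded right shift for tosa.mul. Below is a minimal standalone sketch of that arithmetic, using plain int64_t instead of APInt and omitting the int32 range check; the names are illustrative, not the patch's:

#include <cassert>
#include <cstdint>

// Widen the product, add half of 2^shift to round to nearest, then shift right.
int64_t mulWithShift(int32_t lhs, int32_t rhs, int32_t shift) {
  int64_t product = static_cast<int64_t>(lhs) * static_cast<int64_t>(rhs);
  if (shift > 0) {
    product += int64_t(1) << (shift - 1); // rounding term
    product >>= shift;                    // arithmetic shift on the widened value
  }
  return product;
}

int main() {
  // Matches the first element of test_mul_i32_shift: 135 * 3 with shift 2.
  assert(mulWithShift(135, 3, 2) == 101);
  // Multiplying that intermediate by itself with the same shift gives 2550,
  // the value the test expects.
  assert(mulWithShift(101, 101, 2) == 2550);
  return 0;
}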

Contributor

@lhutton1 lhutton1 left a comment

I wonder if it's worth limiting the size of the binary operations that are folded - would the runtime of folding a larger input, say tensor<1024x1024x1024x1024xf32>, be reasonable?

@Tai78641
Contributor Author

I wonder if it's worth limiting the size of the binary operations that are folded - would the runtime of folding a larger input, say tensor<1024x1024x1024x1024xf32>, be reasonable?

I looked in the code base for examples of limiting folding by the number of elements, with no luck.
I also found that TFLite's ConstFoldBinaryOp imposes no limit on the number of elements.

@lhutton1
Contributor

I'm curious if there's a use case you had in mind for these? In the past, there have been changes that operate in the opposite direction, e.g. https://reviews.llvm.org/D124685, due to compilation time.

@Tai78641
Contributor Author

I'm curious if there's a use case you had in mind for these? In the past, there have been changes that operate in the opposite direction, e.g. https://reviews.llvm.org/D124685, due to compilation time.

These are added to collapse the constant expressions we encounter in shape expressions when dynamic shapes are specialized to specific values.
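
A hypothetical illustration of that use case (names and values are invented): once dynamic dimensions are specialized, shape arithmetic becomes element-wise ops on non-splat constants that this folder can collapse:

%shape  = "tosa.const"() <{value = dense<[1, 8, 8, 16]> : tensor<4xi32>}> : () -> tensor<4xi32>
%pad    = "tosa.const"() <{value = dense<[0, 8, 8, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
%padded = tosa.add %shape, %pad : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// with this change, %padded folds to dense<[1, 16, 16, 16]> : tensor<4xi32>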

@GeorgeARM
Contributor

I share Luke's concern here. This tensor folding could end up being quite expensive and is not explicitly controllable.
So the options are to:

  • have explicit shape operations where we do the folding
  • add some size heuristics to the folders (this should be relatively straightforward if we mainly care about shape folding)
  • pull them out into some kind of optimization pass
  • other thoughts?

@Tai78641
Contributor Author

I share Luke's concern here. This tensor folding could end up being quite expensive and is not explicitly controllable. So the options are to:

  • have explicit shape operations where we do the folding
  • add some size heuristics to the folders (this should be relatively straightforward if we mainly care about shape folding)
  • pull them out into some kind of optimization pass
  • other thoughts?

Added a cap of 128 elements in binaryFolder.

if (llvm::isa<IntegerType>(lETy)) {
auto lvalues = lhs.getValues<APInt>();
auto rvalues = rhs.getValues<APInt>();
if (lvalues.size() != rvalues.size()) {
Contributor

Per the LLVM style guide, this if statement shouldn't have curly braces.

Contributor Author

fixed

if (llvm::isa<FloatType>(lETy)) {
auto lvalues = lhs.getValues<APFloat>();
auto rvalues = rhs.getValues<APFloat>();
if (lvalues.size() != rvalues.size()) {
Contributor

Per the LLVM style guide, this if statement shouldn't have curly braces.

Contributor Author

fixed

}
SmallVector<APInt> results;
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
auto result = IntFolder()(l, r);
Contributor

Is this constructing an IntFolder on each loop iteration just to call an operator overload? If so, could it be hoisted out of the loop?

Contributor Author

fixed
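
For reference, a generic, standalone sketch of the hoisting the reviewer suggests (illustrative only, not the actual patch code; the same idea applies to the FloatFolder loop below):

#include <cstddef>
#include <functional>
#include <vector>

// Construct the folder functor once, outside the loop, instead of on every iteration.
template <typename Folder, typename T>
std::vector<T> foldElementwise(const std::vector<T> &lhs, const std::vector<T> &rhs) {
  Folder fold; // hoisted out of the loop
  std::vector<T> results;
  results.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i)
    results.push_back(fold(lhs[i], rhs[i]));
  return results;
}

// e.g. foldElementwise<std::plus<int>, int>({1, 2}, {3, 4}) yields {4, 6}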


const int64_t MAX_ELEMENT_COUNT = 128;
if (lhsCount > MAX_ELEMENT_COUNT) {
// to prevent long compile time, skip if too many elements
Contributor

If we hoist this comment above the if, we can probably remove the MAX_ELEMENT_COUNT variable completely, since it will be obvious from the comment what the magic number 128 means. We can then also remove the {} on the if statement, as per the LLVM style guide.

Contributor Author

fixed
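
A standalone sketch of what the simplified guard could look like after that suggestion (hypothetical names and a plain std::vector stand-in, not the committed code):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// The comment explains the magic number, so no named constant and no braces
// on the single-statement if.
std::optional<std::vector<int64_t>> foldAdd(const std::vector<int64_t> &lhs,
                                            const std::vector<int64_t> &rhs) {
  // To prevent long compile times, skip folding if there are too many elements.
  if (lhs.size() > 128 || lhs.size() != rhs.size())
    return std::nullopt;
  std::vector<int64_t> results;
  results.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i)
    results.push_back(lhs[i] + rhs[i]);
  return results;
}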

// FloatFolder() may return either APFloat or APInt (comparison functions)
SmallVector<FloatResultAPType> results;
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
auto result = FloatFolder()(l, r);
Contributor

Is this constructing a FloatFolder on each loop iteration just to call an operator overload?

Contributor Author

fixed

@Tai78641 Tai78641 force-pushed the pr_enhance_folder branch from 2206984 to b37535a on March 4, 2025 at 19:12
@Tai78641 Tai78641 requested a review from FranklandJack on March 4, 2025 at 19:13
@Tai78641 Tai78641 force-pushed the pr_enhance_folder branch 2 times, most recently from 558851b to 7bc9362 on March 7, 2025 at 20:07
This enhances the folder for TOSA binary operators to support non-splat constant attributes for
the following ops:
 - mul
 - add
 - sub
 - greater
 - greater_equal
 - equal
@Tai78641 Tai78641 force-pushed the pr_enhance_folder branch from 7bc9362 to a579d4c on March 7, 2025 at 21:08
lhutton1 added a commit to lhutton1/llvm-project that referenced this pull request Apr 28, 2025
Change the folder for mul with a shift such that the
rounding happens correctly according to the spec
pseudo-code.

Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from: llvm#128059

Co-authored-by: Tai Ly <[email protected]>
Change-Id: I3a8cf816cbf71c9ab839eb4f52768904cea29935
lhutton1 added a commit that referenced this pull request May 8, 2025
Change the folder for mul with a shift such that the rounding happens
correctly according to the spec
pseudo-code.

Fixes:
https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from:
#128059

Co-authored-by: Tai Ly <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request May 8, 2025
Change the folder for mul with a shift such that the rounding happens
correctly according to the spec
pseudo-code.

Fixes:
https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040
Partial cherry-pick from:
llvm/llvm-project#128059

Co-authored-by: Tai Ly <[email protected]>
@Tai78641 Tai78641 closed this Jun 2, 2025