-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Enhance folder for Tosa binary operators #128059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) Changes: This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops: mul, add, sub, greater, greater_equal, equal.
Patch is 24.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128059.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..2c6c6e2ed284c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//
-template <typename IntFolder, typename FloatFolder>
+template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
- if (lETy != rETy)
- return {};
+ if (!rhs || !lhs)
+ return {};
+
+ auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ if (lETy != rETy)
+ return {};
+
+ if (!lETy.isIntOrFloat())
+ return {};
+ if (rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(lETy)) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
}
+ if (llvm::isa<IntegerType>(lETy)) {
+ auto lvalues = lhs.getValues<APInt>();
+ auto rvalues = rhs.getValues<APInt>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ SmallVector<APInt> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = IntFolder()(l, r);
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(returnTy, results);
+ }
+
+ if (llvm::isa<FloatType>(lETy)) {
+ auto lvalues = lhs.getValues<APFloat>();
+ auto rvalues = rhs.getValues<APFloat>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
+ SmallVector<FloatResultAPType> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = FloatFolder()(l, r);
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(returnTy, results);
+ }
+
return {};
}
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
+ // comparison FloatFolder() functions return APInt values
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
+}
+
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
+ // arithmetic FloatFolder() functions return APFloat values
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
+}
+
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
+ lhsAttr, rhsAttr, resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}
namespace {
+
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+ unsigned bitwidth) {
+ APInt result = lhs.sext(64) * rhs.sext(64);
+
+ if (shift > 0) {
+ auto round = APInt(64, 1) << (shift - 1);
+ result += round;
+ result.ashrInPlace(shift);
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ if (!(result.getSExtValue() >= INT32_MIN &&
+ result.getSExtValue() <= INT32_MAX)) {
+ // REQUIRE failed
+ return std::nullopt;
+ }
+ }
+
+ return result.trunc(bitwidth);
+}
+
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- if (llvm::isa<IntegerType>(ty.getElementType())) {
- APInt l = lhs.getSplatValue<APInt>();
- APInt r = rhs.getSplatValue<APInt>();
+ if (!lhs || !rhs)
+ return {};
+
+ // REQUIRE(0 <= shift && shift <= 63);
+ if (!(0 <= shift && shift <= 63))
+ return {};
- if (shift == 0) {
- return DenseElementsAttr::get(ty, l * r);
+ auto elementType = ty.getElementType();
+ if (!elementType.isIntOrFloat())
+ return {};
+
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+ // REQUIRE(in_t == int32_t || shift == 0);
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
+ return {};
+
+ if (rhs.isSplat() && lhs.isSplat()) {
+ if (llvm::isa<IntegerType>(elementType)) {
+ auto l = lhs.getSplatValue<APInt>();
+ auto r = rhs.getSplatValue<APInt>();
+
+ if (auto result = mulInt(l, r, shift, bitwidth)) {
+ return DenseElementsAttr::get(ty, result.value());
}
+ // mulInt failed
+ return {};
+ }
- auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
- l = l.sext(bitwidth * 2);
- r = r.sext(bitwidth * 2);
+ if (llvm::isa<FloatType>(elementType)) {
+ auto l = lhs.getSplatValue<APFloat>();
+ auto r = rhs.getSplatValue<APFloat>();
auto result = l * r;
- result.lshrInPlace(shift);
- result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
}
+ }
+
+ if (llvm::isa<IntegerType>(elementType)) {
+ auto lvalues = lhs.getValues<APInt>();
+ auto rvalues = rhs.getValues<APInt>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ SmallVector<APInt> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ if (auto result = mulInt(l, r, shift, bitwidth)) {
+ results.push_back(result.value());
+ continue;
+ }
+ // mulInt failed
+ return {};
+ }
+ return DenseElementsAttr::get(ty, results);
+ }
- if (llvm::isa<FloatType>(ty.getElementType())) {
- APFloat l = lhs.getSplatValue<APFloat>();
- APFloat r = rhs.getSplatValue<APFloat>();
- APFloat result = l * r;
- return DenseElementsAttr::get(ty, result);
+ if (llvm::isa<FloatType>(elementType)) {
+ auto lvalues = lhs.getValues<APFloat>();
+ auto rvalues = rhs.getValues<APFloat>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
}
+ SmallVector<APFloat> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = l * r;
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(ty, results);
}
return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
+ lhsAttr, rhsAttr, resultTy);
}
namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+ return comparisonBinaryFolder<APIntFoldGreater,
+ ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreaterEqual,
- ComparisonFold<std::greater_equal<APFloat>>>(
+ return comparisonBinaryFolder<APIntFoldGreaterEqual,
+ ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
- resultTy);
+ return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+ ComparisonFold<std::equal_to<APFloat>>>(
+ lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..5aab368fa044d 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1092,11 +1092,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
func.func @reduce_sum_constant() -> tensor<1x3xi32> {
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
- // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
- // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
- // CHECK: return %[[VAL_3]] : tensor<1x3xi32>
+ // CHECK: %[[K:.*]] = "tosa.const"() <{value = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+ // CHECK: return %[[K]] : tensor<1x3xi32>
%arg0 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg1 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..fee1ce7793b12 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize --test-constant-fold %s | FileCheck %s
+
+// -----
// CHECK-LABEL: func @test_const
func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +9,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
+// -----
+
// CHECK-LABEL: func @test_const_i64
func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
// CHECK: tosa.const
@@ -14,10 +18,218 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
return %0 : tensor<4xi64>
}
+// -----
+
// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<*xi1> {
// CHECK: tosa.equal
// CHECK-NEXT: return
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
- return
+ return %0 : tensor<*xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32
+func.func @test_mul_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[9, 36, 36, 81]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {value = dense<[1, 2, -2, -3]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+ %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32_shift
+func.func @test_mul_i32_shift() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2550, 8100, 2, 2025]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {value = dense<[135, 240, -4, -120]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+ %shift = "tosa.const"() { value = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_f32
+func.func @test_mul_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2.304000e+01, 58.9824028, 1.6384002, 14.7456007]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %rhs = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_f32
+func.func @test_add_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[7.500000e+00, 9.300000e+00, 3.69999981, 2.100000e+00]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.add %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %y = tosa.add %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = tosa.add %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_i32
+func.func @test_add_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[75, 93, 37, 21]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.add %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %y = tosa.add %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %result = tosa.add %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_f32
+func.func @test_sub_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-1.500000e+00, 0.300000191, -5.300000e+00, -6.900000e+00]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.sub %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %y = tosa.sub %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = tosa.sub %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_i32
+func.func @test_sub_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-15, 3, -53, -69]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.sub %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %y = tosa.sub %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %result = tosa.sub %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_f32
+func.func @test_greater_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {value = dense<[1.7, 2.3, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.greater %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.greater %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.greater %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_i32
+func.func @test_greater_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {value = dense<[17, 23, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {value = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.greater %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.greater %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.greater %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_f32
+func.func @test_greater_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.c...
[truncated]
|
c4bbd39
to
d050d54
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if it's worth limiting the size of the binary operations that are folded - would the runtime of folding a larger input, say `tensor<1024x1024x1024x1024xf32>`, be reasonable?
I looked in the code base for examples of limiting fold by number of elements, with no luck. |
I'm curious if there's a use-case you had in mind for these? In the past, there have been changes that operate in the opposite direction, e.g. https://reviews.llvm.org/D124685, due to compilation time. |
These are added to collapse constant expressions we encounter in shape expressions when dynamic shapes are specialized to specific values. |
I share Luke's concern here. This tensor folding could end up being quite expensive and not explicitly controllable.
|
d050d54
to
8b25dd5
Compare
added a cap at 128 elements in binaryFolder |
8b25dd5
to
2206984
Compare
if (llvm::isa<IntegerType>(lETy)) { | ||
auto lvalues = lhs.getValues<APInt>(); | ||
auto rvalues = rhs.getValues<APInt>(); | ||
if (lvalues.size() != rvalues.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the llvm style guide this if statement shouldn't have curly braces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
if (llvm::isa<FloatType>(lETy)) { | ||
auto lvalues = lhs.getValues<APFloat>(); | ||
auto rvalues = rhs.getValues<APFloat>(); | ||
if (lvalues.size() != rvalues.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the llvm style guide this if statement shouldn't have curly braces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
} | ||
SmallVector<APInt> results; | ||
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) { | ||
auto result = IntFolder()(l, r); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this constructing an `IntFolder` on each loop iteration just to call an operator overload? If so, could it be hoisted out of the loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
||
const int64_t MAX_ELEMENT_COUNT = 128; | ||
if (lhsCount > MAX_ELEMENT_COUNT) { | ||
// to prevent long compile time, skip if too many elements |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we hoist this comment above the `if`, we can probably remove the `MAX_ELEMENT_COUNT` variable completely, since it'll be obvious from the comment what the magic number 128 means. We can then also remove the `{}` on the `if` statement, as per the LLVM style guide.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
// FloatFolder() may return either APFloat or APInt (comparison functions) | ||
SmallVector<FloatResultAPType> results; | ||
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) { | ||
auto result = FloatFolder()(l, r); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this constructing a `FloatFolder` on each loop iteration just to call an operator overload?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
2206984
to
b37535a
Compare
558851b
to
7bc9362
Compare
This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops: - mul - add - sub - greater - greater_equal - equal
7bc9362
to
a579d4c
Compare
Change the folder for mul with a shift such that the rounding happens correctly according to the spec pseudo-code. Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040 Partial cherry-pick from: llvm#128059 Co-authored-by: Tai Ly <[email protected]> Change-Id: I3a8cf816cbf71c9ab839eb4f52768904cea29935
Change the folder for mul with a shift such that the rounding happens correctly according to the spec pseudo-code. Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040 Partial cherry-pick from: #128059 Co-authored-by: Tai Ly <[email protected]>
Change the folder for mul with a shift such that the rounding happens correctly according to the spec pseudo-code. Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040 Partial cherry-pick from: llvm/llvm-project#128059 Co-authored-by: Tai Ly <[email protected]>
This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops: