Skip to content

[mlir][tosa] Enhance folder for Tosa binary operators #128059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 153 additions & 32 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,15 +501,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//

template <typename IntFolder, typename FloatFolder>
template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
return {};
if (!rhs || !lhs)
return {};

auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
return {};

if (!lETy.isIntOrFloat())
return {};

if (rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(lETy)) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
Expand All @@ -525,9 +531,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
}

auto lhsCount = lhs.getNumElements();
auto rhsCount = rhs.getNumElements();
if (lhsCount != rhsCount)
return {};

// To keep compile time bounded, skip folding when the operands have too many
// elements.
if (lhsCount > 128)
return {};

if (llvm::isa<IntegerType>(lETy)) {
auto lvalues = lhs.getValues<APInt>();
auto rvalues = rhs.getValues<APInt>();
SmallVector<APInt> results;
IntFolder intFolder{};
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
auto result = intFolder(l, r);
results.push_back(result);
}
return DenseElementsAttr::get(returnTy, results);
}

if (llvm::isa<FloatType>(lETy)) {
auto lvalues = lhs.getValues<APFloat>();
auto rvalues = rhs.getValues<APFloat>();
// FloatFolder() may return either APFloat or APInt (comparison functions)
SmallVector<FloatResultAPType> results;
FloatFolder floatFolder{};
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
auto result = floatFolder(l, r);
results.push_back(result);
}
return DenseElementsAttr::get(returnTy, results);
}

return {};
}

/// Folds an element-wise comparison of two constant attributes.
///
/// Thin wrapper over binaryFolder that selects APInt as the element type
/// collected on the floating-point path, because the comparison FloatFolder
/// functors return APInt values rather than APFloat.
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
                                         DenseElementsAttr rhs,
                                         RankedTensorType returnTy) {
  using FloatPathResultTy = APInt;
  return binaryFolder<IntFolder, FloatFolder, FloatPathResultTy>(lhs, rhs,
                                                                 returnTy);
}

/// Folds an element-wise arithmetic operation on two constant attributes.
///
/// Thin wrapper over binaryFolder that selects APFloat as the element type
/// collected on the floating-point path, because the arithmetic FloatFolder
/// functors return APFloat values.
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
                                         DenseElementsAttr rhs,
                                         RankedTensorType returnTy) {
  using FloatPathResultTy = APFloat;
  return binaryFolder<IntFolder, FloatFolder, FloatPathResultTy>(lhs, rhs,
                                                                 returnTy);
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
Expand Down Expand Up @@ -574,8 +630,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
resultTy);
return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
lhsAttr, rhsAttr, resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -632,32 +688,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}

namespace {

/// Computes `(lhs * rhs) >> shift` following the TOSA specification.
///
/// Both operands are sign-extended to 64 bits before multiplying. When
/// `shift > 0`, the product is rounded by adding `1 << (shift - 1)` and then
/// arithmetically shifted right by `shift`.
///
/// \param lhs      left operand.
/// \param rhs      right operand.
/// \param shift    right-shift amount; values <= 0 skip rounding and shifting.
///                 NOTE(review): the upper bound is not checked here and
///                 appears to be validated by the caller; confirm.
/// \param bitwidth width the 64-bit result is truncated to.
/// \returns the truncated result, or std::nullopt when `shift > 0` and the
///          shifted product does not fit in the int32_t range.
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  const APInt wideLhs = lhs.sext(64);
  const APInt wideRhs = rhs.sext(64);
  APInt product = wideLhs * wideRhs;

  // Without a positive shift there is no rounding and no range requirement.
  if (shift <= 0)
    return product.trunc(bitwidth);

  // Add half of the divisor so the arithmetic shift rounds instead of
  // truncating.
  product += APInt(64, 1) << (shift - 1);
  product.ashrInPlace(shift);

  // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
  const int64_t scaled = product.getSExtValue();
  if (scaled < INT32_MIN || scaled > INT32_MAX) {
    // REQUIRE failed
    return std::nullopt;
  }

  return product.trunc(bitwidth);
}

DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(ty.getElementType())) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
if (!lhs || !rhs)
return {};

// REQUIRE(0 <= shift && shift <= 63);
if (!(0 <= shift && shift <= 63))
return {};

auto elementType = ty.getElementType();
if (!elementType.isIntOrFloat())
return {};

if (shift == 0) {
return DenseElementsAttr::get(ty, l * r);
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
// REQUIRE(in_t == int32_t || shift == 0);
if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
return {};

if (rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(elementType)) {
auto l = lhs.getSplatValue<APInt>();
auto r = rhs.getSplatValue<APInt>();

if (auto result = mulInt(l, r, shift, bitwidth)) {
return DenseElementsAttr::get(ty, result.value());
}
// mulInt failed
return {};
}

auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
l = l.sext(bitwidth * 2);
r = r.sext(bitwidth * 2);
if (llvm::isa<FloatType>(elementType)) {
auto l = lhs.getSplatValue<APFloat>();
auto r = rhs.getSplatValue<APFloat>();
auto result = l * r;
result.lshrInPlace(shift);
result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
}
}

if (llvm::isa<FloatType>(ty.getElementType())) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
APFloat result = l * r;
return DenseElementsAttr::get(ty, result);
if (llvm::isa<IntegerType>(elementType)) {
auto lvalues = lhs.getValues<APInt>();
auto rvalues = rhs.getValues<APInt>();
if (lvalues.size() != rvalues.size()) {
return {};
}
SmallVector<APInt> results;
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
if (auto result = mulInt(l, r, shift, bitwidth)) {
results.push_back(result.value());
continue;
}
// mulInt failed
return {};
}
return DenseElementsAttr::get(ty, results);
}

if (llvm::isa<FloatType>(elementType)) {
auto lvalues = lhs.getValues<APFloat>();
auto rvalues = rhs.getValues<APFloat>();
if (lvalues.size() != rvalues.size()) {
return {};
}
SmallVector<APFloat> results;
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
auto result = l * r;
results.push_back(result);
}
return DenseElementsAttr::get(ty, results);
}

return {};
Expand Down Expand Up @@ -732,8 +852,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
resultTy);
return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
lhsAttr, rhsAttr, resultTy);
}

namespace {
Expand Down Expand Up @@ -774,7 +894,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
return comparisonBinaryFolder<APIntFoldGreater,
ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}

Expand All @@ -788,8 +909,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<APIntFoldGreaterEqual,
ComparisonFold<std::greater_equal<APFloat>>>(
return comparisonBinaryFolder<APIntFoldGreaterEqual,
ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}

Expand All @@ -813,9 +934,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
resultTy);
return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
ComparisonFold<std::equal_to<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
Expand Down
7 changes: 2 additions & 5 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1082,11 +1082,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {

func.func @reduce_sum_constant() -> tensor<1x3xi32> {
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
// CHECK: return %[[VAL_3]] : tensor<1x3xi32>
// CHECK: %[[K:.*]] = "tosa.const"() <{values = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
// CHECK: return %[[K]] : tensor<1x3xi32>
%arg0 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg1 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
Expand Down
Loading