Skip to content

Commit 7bc9362

Browse files
committed
[mlir][tosa] Enhance folder for Tosa binary operators
This enhances folder for tosa binary operators to support non-splat constant attributes for following ops: - mul - add - sub - greater - greater_equal - equal
1 parent fefb685 commit 7bc9362

File tree

3 files changed

+370
-40
lines changed

3 files changed

+370
-40
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 153 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -501,15 +501,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
501501
// Operator Folders.
502502
//===----------------------------------------------------------------------===//
503503

504-
template <typename IntFolder, typename FloatFolder>
504+
template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
505505
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
506506
RankedTensorType returnTy) {
507-
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
508-
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
509-
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
510-
if (lETy != rETy)
511-
return {};
507+
if (!rhs || !lhs)
508+
return {};
509+
510+
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
511+
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
512+
if (lETy != rETy)
513+
return {};
514+
515+
if (!lETy.isIntOrFloat())
516+
return {};
512517

518+
if (rhs.isSplat() && lhs.isSplat()) {
513519
if (llvm::isa<IntegerType>(lETy)) {
514520
APInt l = lhs.getSplatValue<APInt>();
515521
APInt r = rhs.getSplatValue<APInt>();
@@ -525,9 +531,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
525531
}
526532
}
527533

534+
auto lhsCount = lhs.getNumElements();
535+
auto rhsCount = rhs.getNumElements();
536+
if (lhsCount != rhsCount)
537+
return {};
538+
539+
// to prevent long compile time, skip if too many elements
540+
if (lhsCount > 128)
541+
return {};
542+
543+
if (llvm::isa<IntegerType>(lETy)) {
544+
auto lvalues = lhs.getValues<APInt>();
545+
auto rvalues = rhs.getValues<APInt>();
546+
SmallVector<APInt> results;
547+
IntFolder intFolder{};
548+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
549+
auto result = intFolder(l, r);
550+
results.push_back(result);
551+
}
552+
return DenseElementsAttr::get(returnTy, results);
553+
}
554+
555+
if (llvm::isa<FloatType>(lETy)) {
556+
auto lvalues = lhs.getValues<APFloat>();
557+
auto rvalues = rhs.getValues<APFloat>();
558+
// FloatFolder() may return either APFloat or APInt (comparison functions)
559+
SmallVector<FloatResultAPType> results;
560+
FloatFolder floatFolder{};
561+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
562+
auto result = floatFolder(l, r);
563+
results.push_back(result);
564+
}
565+
return DenseElementsAttr::get(returnTy, results);
566+
}
567+
528568
return {};
529569
}
530570

571+
template <typename IntFolder, typename FloatFolder>
572+
DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
573+
DenseElementsAttr rhs,
574+
RankedTensorType returnTy) {
575+
// comparison FloatFolder() functions return APInt values
576+
return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
577+
}
578+
579+
template <typename IntFolder, typename FloatFolder>
580+
DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
581+
DenseElementsAttr rhs,
582+
RankedTensorType returnTy) {
583+
// arithmetic FloatFolder() functions return APFloat values
584+
return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
585+
}
586+
531587
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
532588
if (llvm::isa<FloatType>(elemType))
533589
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -574,8 +630,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
574630
if (!lhsAttr || !rhsAttr)
575631
return {};
576632

577-
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
578-
resultTy);
633+
return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
634+
lhsAttr, rhsAttr, resultTy);
579635
}
580636

581637
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -632,32 +688,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
632688
}
633689

634690
namespace {
691+
692+
// calculate lhs * rhs >> shift according to TOSA Spec
693+
// return nullopt if result is not in range of int32_t when shift > 0
694+
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
695+
unsigned bitwidth) {
696+
APInt result = lhs.sext(64) * rhs.sext(64);
697+
698+
if (shift > 0) {
699+
auto round = APInt(64, 1) << (shift - 1);
700+
result += round;
701+
result.ashrInPlace(shift);
702+
// REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
703+
if (!(result.getSExtValue() >= INT32_MIN &&
704+
result.getSExtValue() <= INT32_MAX)) {
705+
// REQUIRE failed
706+
return std::nullopt;
707+
}
708+
}
709+
710+
return result.trunc(bitwidth);
711+
}
712+
635713
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
636714
RankedTensorType ty, int32_t shift) {
637-
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
638-
if (llvm::isa<IntegerType>(ty.getElementType())) {
639-
APInt l = lhs.getSplatValue<APInt>();
640-
APInt r = rhs.getSplatValue<APInt>();
715+
if (!lhs || !rhs)
716+
return {};
717+
718+
// REQUIRE(0 <= shift && shift <= 63);
719+
if (!(0 <= shift && shift <= 63))
720+
return {};
721+
722+
auto elementType = ty.getElementType();
723+
if (!elementType.isIntOrFloat())
724+
return {};
641725

642-
if (shift == 0) {
643-
return DenseElementsAttr::get(ty, l * r);
726+
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
727+
// REQUIRE(in_t == int32_t || shift == 0);
728+
if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
729+
return {};
730+
731+
if (rhs.isSplat() && lhs.isSplat()) {
732+
if (llvm::isa<IntegerType>(elementType)) {
733+
auto l = lhs.getSplatValue<APInt>();
734+
auto r = rhs.getSplatValue<APInt>();
735+
736+
if (auto result = mulInt(l, r, shift, bitwidth)) {
737+
return DenseElementsAttr::get(ty, result.value());
644738
}
739+
// mulInt failed
740+
return {};
741+
}
645742

646-
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
647-
l = l.sext(bitwidth * 2);
648-
r = r.sext(bitwidth * 2);
743+
if (llvm::isa<FloatType>(elementType)) {
744+
auto l = lhs.getSplatValue<APFloat>();
745+
auto r = rhs.getSplatValue<APFloat>();
649746
auto result = l * r;
650-
result.lshrInPlace(shift);
651-
result = result.trunc(bitwidth);
652747
return DenseElementsAttr::get(ty, result);
653748
}
749+
}
654750

655-
if (llvm::isa<FloatType>(ty.getElementType())) {
656-
APFloat l = lhs.getSplatValue<APFloat>();
657-
APFloat r = rhs.getSplatValue<APFloat>();
658-
APFloat result = l * r;
659-
return DenseElementsAttr::get(ty, result);
751+
if (llvm::isa<IntegerType>(elementType)) {
752+
auto lvalues = lhs.getValues<APInt>();
753+
auto rvalues = rhs.getValues<APInt>();
754+
if (lvalues.size() != rvalues.size()) {
755+
return {};
756+
}
757+
SmallVector<APInt> results;
758+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
759+
if (auto result = mulInt(l, r, shift, bitwidth)) {
760+
results.push_back(result.value());
761+
continue;
762+
}
763+
// mulInt failed
764+
return {};
765+
}
766+
return DenseElementsAttr::get(ty, results);
767+
}
768+
769+
if (llvm::isa<FloatType>(elementType)) {
770+
auto lvalues = lhs.getValues<APFloat>();
771+
auto rvalues = rhs.getValues<APFloat>();
772+
if (lvalues.size() != rvalues.size()) {
773+
return {};
660774
}
775+
SmallVector<APFloat> results;
776+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
777+
auto result = l * r;
778+
results.push_back(result);
779+
}
780+
return DenseElementsAttr::get(ty, results);
661781
}
662782

663783
return {};
@@ -732,8 +852,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
732852
if (!lhsAttr || !rhsAttr)
733853
return {};
734854

735-
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
736-
resultTy);
855+
return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
856+
lhsAttr, rhsAttr, resultTy);
737857
}
738858

739859
namespace {
@@ -774,7 +894,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
774894
if (!lhsAttr || !rhsAttr)
775895
return {};
776896

777-
return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
897+
return comparisonBinaryFolder<APIntFoldGreater,
898+
ComparisonFold<std::greater<APFloat>>>(
778899
lhsAttr, rhsAttr, resultTy);
779900
}
780901

@@ -788,8 +909,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
788909
if (!lhsAttr || !rhsAttr)
789910
return {};
790911

791-
return binaryFolder<APIntFoldGreaterEqual,
792-
ComparisonFold<std::greater_equal<APFloat>>>(
912+
return comparisonBinaryFolder<APIntFoldGreaterEqual,
913+
ComparisonFold<std::greater_equal<APFloat>>>(
793914
lhsAttr, rhsAttr, resultTy);
794915
}
795916

@@ -813,9 +934,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
813934
if (!lhsAttr || !rhsAttr)
814935
return {};
815936

816-
return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
817-
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
818-
resultTy);
937+
return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
938+
ComparisonFold<std::equal_to<APFloat>>>(
939+
lhsAttr, rhsAttr, resultTy);
819940
}
820941

821942
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,11 +1082,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
10821082

10831083
func.func @reduce_sum_constant() -> tensor<1x3xi32> {
10841084
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
1085-
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1086-
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1087-
// CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
1088-
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
1089-
// CHECK: return %[[VAL_3]] : tensor<1x3xi32>
1085+
// CHECK: %[[K:.*]] = "tosa.const"() <{values = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
1086+
// CHECK: return %[[K]] : tensor<1x3xi32>
10901087
%arg0 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
10911088
%arg1 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
10921089
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

0 commit comments

Comments
 (0)