Skip to content

Commit c4bbd39

Browse files
committed
[mlir][tosa] Enhance folder for Tosa binary operators
This enhances folder for tosa binary operators to support non-splat constant attributes for following ops: - mul - add - sub - greater - greater_equal - equal Signed-off-by: Tai Ly <[email protected]> Change-Id: I3198a808988a71b5894d8f7c410b407340564c38
1 parent df9d3c2 commit c4bbd39

File tree

3 files changed

+365
-40
lines changed

3 files changed

+365
-40
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 148 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
563563
// Operator Folders.
564564
//===----------------------------------------------------------------------===//
565565

566-
template <typename IntFolder, typename FloatFolder>
566+
template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
567567
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
568568
RankedTensorType returnTy) {
569-
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
570-
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
571-
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
572-
if (lETy != rETy)
573-
return {};
569+
if (!rhs || !lhs)
570+
return {};
571+
572+
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
573+
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
574+
if (lETy != rETy)
575+
return {};
576+
577+
if (!lETy.isIntOrFloat())
578+
return {};
574579

580+
if (rhs.isSplat() && lhs.isSplat()) {
575581
if (llvm::isa<IntegerType>(lETy)) {
576582
APInt l = lhs.getSplatValue<APInt>();
577583
APInt r = rhs.getSplatValue<APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
587593
}
588594
}
589595

596+
if (llvm::isa<IntegerType>(lETy)) {
597+
auto lvalues = lhs.getValues<APInt>();
598+
auto rvalues = rhs.getValues<APInt>();
599+
if (lvalues.size() != rvalues.size()) {
600+
return {};
601+
}
602+
SmallVector<APInt> results;
603+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
604+
auto result = IntFolder()(l, r);
605+
results.push_back(result);
606+
}
607+
return DenseElementsAttr::get(returnTy, results);
608+
}
609+
610+
if (llvm::isa<FloatType>(lETy)) {
611+
auto lvalues = lhs.getValues<APFloat>();
612+
auto rvalues = rhs.getValues<APFloat>();
613+
if (lvalues.size() != rvalues.size()) {
614+
return {};
615+
}
616+
// FloatFolder() may return either APFloat or APInt (comparison functions)
617+
SmallVector<FloatResultAPType> results;
618+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
619+
auto result = FloatFolder()(l, r);
620+
results.push_back(result);
621+
}
622+
return DenseElementsAttr::get(returnTy, results);
623+
}
624+
590625
return {};
591626
}
592627

628+
template <typename IntFolder, typename FloatFolder>
629+
DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
630+
DenseElementsAttr rhs,
631+
RankedTensorType returnTy) {
632+
// comparison FloatFolder() functions return APInt values
633+
return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
634+
}
635+
636+
template <typename IntFolder, typename FloatFolder>
637+
DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
638+
DenseElementsAttr rhs,
639+
RankedTensorType returnTy) {
640+
// arithmetic FloatFolder() functions return APFloat values
641+
return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
642+
}
643+
593644
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
594645
if (llvm::isa<FloatType>(elemType))
595646
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
636687
if (!lhsAttr || !rhsAttr)
637688
return {};
638689

639-
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
640-
resultTy);
690+
return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
691+
lhsAttr, rhsAttr, resultTy);
641692
}
642693

643694
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
693744
}
694745

695746
namespace {
747+
748+
// calculate lhs * rhs >> shift according to TOSA Spec
749+
// return nullopt if result is not in range of int32_t when shift > 0
750+
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
751+
unsigned bitwidth) {
752+
APInt result = lhs.sext(64) * rhs.sext(64);
753+
754+
if (shift > 0) {
755+
auto round = APInt(64, 1) << (shift - 1);
756+
result += round;
757+
result.ashrInPlace(shift);
758+
// REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
759+
if (!(result.getSExtValue() >= INT32_MIN &&
760+
result.getSExtValue() <= INT32_MAX)) {
761+
// REQUIRE failed
762+
return std::nullopt;
763+
}
764+
}
765+
766+
return result.trunc(bitwidth);
767+
}
768+
696769
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
697770
RankedTensorType ty, int32_t shift) {
698-
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
699-
if (llvm::isa<IntegerType>(ty.getElementType())) {
700-
APInt l = lhs.getSplatValue<APInt>();
701-
APInt r = rhs.getSplatValue<APInt>();
771+
if (!lhs || !rhs)
772+
return {};
773+
774+
// REQUIRE(0 <= shift && shift <= 63);
775+
if (!(0 <= shift && shift <= 63))
776+
return {};
702777

703-
if (shift == 0) {
704-
return DenseElementsAttr::get(ty, l * r);
778+
auto elementType = ty.getElementType();
779+
if (!elementType.isIntOrFloat())
780+
return {};
781+
782+
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
783+
// REQUIRE(in_t == int32_t || shift == 0);
784+
if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
785+
return {};
786+
787+
if (rhs.isSplat() && lhs.isSplat()) {
788+
if (llvm::isa<IntegerType>(elementType)) {
789+
auto l = lhs.getSplatValue<APInt>();
790+
auto r = rhs.getSplatValue<APInt>();
791+
792+
if (auto result = mulInt(l, r, shift, bitwidth)) {
793+
return DenseElementsAttr::get(ty, result.value());
705794
}
795+
// mulInt failed
796+
return {};
797+
}
706798

707-
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
708-
l = l.sext(bitwidth * 2);
709-
r = r.sext(bitwidth * 2);
799+
if (llvm::isa<FloatType>(elementType)) {
800+
auto l = lhs.getSplatValue<APFloat>();
801+
auto r = rhs.getSplatValue<APFloat>();
710802
auto result = l * r;
711-
result.lshrInPlace(shift);
712-
result = result.trunc(bitwidth);
713803
return DenseElementsAttr::get(ty, result);
714804
}
805+
}
806+
807+
if (llvm::isa<IntegerType>(elementType)) {
808+
auto lvalues = lhs.getValues<APInt>();
809+
auto rvalues = rhs.getValues<APInt>();
810+
if (lvalues.size() != rvalues.size()) {
811+
return {};
812+
}
813+
SmallVector<APInt> results;
814+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
815+
if (auto result = mulInt(l, r, shift, bitwidth)) {
816+
results.push_back(result.value());
817+
continue;
818+
}
819+
// mulInt failed
820+
return {};
821+
}
822+
return DenseElementsAttr::get(ty, results);
823+
}
715824

716-
if (llvm::isa<FloatType>(ty.getElementType())) {
717-
APFloat l = lhs.getSplatValue<APFloat>();
718-
APFloat r = rhs.getSplatValue<APFloat>();
719-
APFloat result = l * r;
720-
return DenseElementsAttr::get(ty, result);
825+
if (llvm::isa<FloatType>(elementType)) {
826+
auto lvalues = lhs.getValues<APFloat>();
827+
auto rvalues = rhs.getValues<APFloat>();
828+
if (lvalues.size() != rvalues.size()) {
829+
return {};
721830
}
831+
SmallVector<APFloat> results;
832+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
833+
auto result = l * r;
834+
results.push_back(result);
835+
}
836+
return DenseElementsAttr::get(ty, results);
722837
}
723838

724839
return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
793908
if (!lhsAttr || !rhsAttr)
794909
return {};
795910

796-
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
797-
resultTy);
911+
return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
912+
lhsAttr, rhsAttr, resultTy);
798913
}
799914

800915
namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
835950
if (!lhsAttr || !rhsAttr)
836951
return {};
837952

838-
return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
953+
return comparisonBinaryFolder<APIntFoldGreater,
954+
ComparisonFold<std::greater<APFloat>>>(
839955
lhsAttr, rhsAttr, resultTy);
840956
}
841957

@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
849965
if (!lhsAttr || !rhsAttr)
850966
return {};
851967

852-
return binaryFolder<APIntFoldGreaterEqual,
853-
ComparisonFold<std::greater_equal<APFloat>>>(
968+
return comparisonBinaryFolder<APIntFoldGreaterEqual,
969+
ComparisonFold<std::greater_equal<APFloat>>>(
854970
lhsAttr, rhsAttr, resultTy);
855971
}
856972

@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
874990
if (!lhsAttr || !rhsAttr)
875991
return {};
876992

877-
return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
878-
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
879-
resultTy);
993+
return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
994+
ComparisonFold<std::equal_to<APFloat>>>(
995+
lhsAttr, rhsAttr, resultTy);
880996
}
881997

882998
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,11 +1092,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
10921092

10931093
func.func @reduce_sum_constant() -> tensor<1x3xi32> {
10941094
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
1095-
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1096-
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1097-
// CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
1098-
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
1099-
// CHECK: return %[[VAL_3]] : tensor<1x3xi32>
1095+
// CHECK: %[[K:.*]] = "tosa.const"() <{value = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
1096+
// CHECK: return %[[K]] : tensor<1x3xi32>
11001097
%arg0 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
11011098
%arg1 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
11021099
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

0 commit comments

Comments
 (0)