Skip to content

Commit b37535a

Browse files
committed
[mlir][tosa] Enhance folder for Tosa binary operators
This enhances folder for tosa binary operators to support non-splat constant attributes for following ops: - mul - add - sub - greater - greater_equal - equal
1 parent 6f25614 commit b37535a

File tree

3 files changed

+370
-40
lines changed

3 files changed

+370
-40
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 153 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -548,15 +548,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
548548
// Operator Folders.
549549
//===----------------------------------------------------------------------===//
550550

551-
template <typename IntFolder, typename FloatFolder>
551+
template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
552552
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
553553
RankedTensorType returnTy) {
554-
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
555-
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
556-
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
557-
if (lETy != rETy)
558-
return {};
554+
if (!rhs || !lhs)
555+
return {};
556+
557+
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
558+
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
559+
if (lETy != rETy)
560+
return {};
561+
562+
if (!lETy.isIntOrFloat())
563+
return {};
559564

565+
if (rhs.isSplat() && lhs.isSplat()) {
560566
if (llvm::isa<IntegerType>(lETy)) {
561567
APInt l = lhs.getSplatValue<APInt>();
562568
APInt r = rhs.getSplatValue<APInt>();
@@ -572,9 +578,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
572578
}
573579
}
574580

581+
auto lhsCount = lhs.getNumElements();
582+
auto rhsCount = rhs.getNumElements();
583+
if (lhsCount != rhsCount)
584+
return {};
585+
586+
// to prevent long compile time, skip if too many elements
587+
if (lhsCount > 128)
588+
return {};
589+
590+
if (llvm::isa<IntegerType>(lETy)) {
591+
auto lvalues = lhs.getValues<APInt>();
592+
auto rvalues = rhs.getValues<APInt>();
593+
SmallVector<APInt> results;
594+
IntFolder intFolder{};
595+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
596+
auto result = intFolder(l, r);
597+
results.push_back(result);
598+
}
599+
return DenseElementsAttr::get(returnTy, results);
600+
}
601+
602+
if (llvm::isa<FloatType>(lETy)) {
603+
auto lvalues = lhs.getValues<APFloat>();
604+
auto rvalues = rhs.getValues<APFloat>();
605+
// FloatFolder() may return either APFloat or APInt (comparison functions)
606+
SmallVector<FloatResultAPType> results;
607+
FloatFolder floatFolder{};
608+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
609+
auto result = floatFolder(l, r);
610+
results.push_back(result);
611+
}
612+
return DenseElementsAttr::get(returnTy, results);
613+
}
614+
575615
return {};
576616
}
577617

618+
template <typename IntFolder, typename FloatFolder>
619+
DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
620+
DenseElementsAttr rhs,
621+
RankedTensorType returnTy) {
622+
// comparison FloatFolder() functions return APInt values
623+
return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
624+
}
625+
626+
template <typename IntFolder, typename FloatFolder>
627+
DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
628+
DenseElementsAttr rhs,
629+
RankedTensorType returnTy) {
630+
// arithmetic FloatFolder() functions return APFloat values
631+
return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
632+
}
633+
578634
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
579635
if (llvm::isa<FloatType>(elemType))
580636
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -621,8 +677,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
621677
if (!lhsAttr || !rhsAttr)
622678
return {};
623679

624-
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
625-
resultTy);
680+
return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
681+
lhsAttr, rhsAttr, resultTy);
626682
}
627683

628684
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -679,32 +735,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
679735
}
680736

681737
namespace {
738+
739+
// calculate lhs * rhs >> shift according to TOSA Spec
740+
// return nullopt if result is not in range of int32_t when shift > 0
741+
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
742+
unsigned bitwidth) {
743+
APInt result = lhs.sext(64) * rhs.sext(64);
744+
745+
if (shift > 0) {
746+
auto round = APInt(64, 1) << (shift - 1);
747+
result += round;
748+
result.ashrInPlace(shift);
749+
// REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
750+
if (!(result.getSExtValue() >= INT32_MIN &&
751+
result.getSExtValue() <= INT32_MAX)) {
752+
// REQUIRE failed
753+
return std::nullopt;
754+
}
755+
}
756+
757+
return result.trunc(bitwidth);
758+
}
759+
682760
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
683761
RankedTensorType ty, int32_t shift) {
684-
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
685-
if (llvm::isa<IntegerType>(ty.getElementType())) {
686-
APInt l = lhs.getSplatValue<APInt>();
687-
APInt r = rhs.getSplatValue<APInt>();
762+
if (!lhs || !rhs)
763+
return {};
764+
765+
// REQUIRE(0 <= shift && shift <= 63);
766+
if (!(0 <= shift && shift <= 63))
767+
return {};
768+
769+
auto elementType = ty.getElementType();
770+
if (!elementType.isIntOrFloat())
771+
return {};
688772

689-
if (shift == 0) {
690-
return DenseElementsAttr::get(ty, l * r);
773+
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
774+
// REQUIRE(in_t == int32_t || shift == 0);
775+
if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
776+
return {};
777+
778+
if (rhs.isSplat() && lhs.isSplat()) {
779+
if (llvm::isa<IntegerType>(elementType)) {
780+
auto l = lhs.getSplatValue<APInt>();
781+
auto r = rhs.getSplatValue<APInt>();
782+
783+
if (auto result = mulInt(l, r, shift, bitwidth)) {
784+
return DenseElementsAttr::get(ty, result.value());
691785
}
786+
// mulInt failed
787+
return {};
788+
}
692789

693-
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
694-
l = l.sext(bitwidth * 2);
695-
r = r.sext(bitwidth * 2);
790+
if (llvm::isa<FloatType>(elementType)) {
791+
auto l = lhs.getSplatValue<APFloat>();
792+
auto r = rhs.getSplatValue<APFloat>();
696793
auto result = l * r;
697-
result.lshrInPlace(shift);
698-
result = result.trunc(bitwidth);
699794
return DenseElementsAttr::get(ty, result);
700795
}
796+
}
701797

702-
if (llvm::isa<FloatType>(ty.getElementType())) {
703-
APFloat l = lhs.getSplatValue<APFloat>();
704-
APFloat r = rhs.getSplatValue<APFloat>();
705-
APFloat result = l * r;
706-
return DenseElementsAttr::get(ty, result);
798+
if (llvm::isa<IntegerType>(elementType)) {
799+
auto lvalues = lhs.getValues<APInt>();
800+
auto rvalues = rhs.getValues<APInt>();
801+
if (lvalues.size() != rvalues.size()) {
802+
return {};
803+
}
804+
SmallVector<APInt> results;
805+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
806+
if (auto result = mulInt(l, r, shift, bitwidth)) {
807+
results.push_back(result.value());
808+
continue;
809+
}
810+
// mulInt failed
811+
return {};
812+
}
813+
return DenseElementsAttr::get(ty, results);
814+
}
815+
816+
if (llvm::isa<FloatType>(elementType)) {
817+
auto lvalues = lhs.getValues<APFloat>();
818+
auto rvalues = rhs.getValues<APFloat>();
819+
if (lvalues.size() != rvalues.size()) {
820+
return {};
707821
}
822+
SmallVector<APFloat> results;
823+
for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
824+
auto result = l * r;
825+
results.push_back(result);
826+
}
827+
return DenseElementsAttr::get(ty, results);
708828
}
709829

710830
return {};
@@ -779,8 +899,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
779899
if (!lhsAttr || !rhsAttr)
780900
return {};
781901

782-
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
783-
resultTy);
902+
return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
903+
lhsAttr, rhsAttr, resultTy);
784904
}
785905

786906
namespace {
@@ -821,7 +941,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
821941
if (!lhsAttr || !rhsAttr)
822942
return {};
823943

824-
return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
944+
return comparisonBinaryFolder<APIntFoldGreater,
945+
ComparisonFold<std::greater<APFloat>>>(
825946
lhsAttr, rhsAttr, resultTy);
826947
}
827948

@@ -835,8 +956,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
835956
if (!lhsAttr || !rhsAttr)
836957
return {};
837958

838-
return binaryFolder<APIntFoldGreaterEqual,
839-
ComparisonFold<std::greater_equal<APFloat>>>(
959+
return comparisonBinaryFolder<APIntFoldGreaterEqual,
960+
ComparisonFold<std::greater_equal<APFloat>>>(
840961
lhsAttr, rhsAttr, resultTy);
841962
}
842963

@@ -860,9 +981,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
860981
if (!lhsAttr || !rhsAttr)
861982
return {};
862983

863-
return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
864-
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
865-
resultTy);
984+
return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
985+
ComparisonFold<std::equal_to<APFloat>>>(
986+
lhsAttr, rhsAttr, resultTy);
866987
}
867988

868989
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,11 +1082,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
10821082

10831083
func.func @reduce_sum_constant() -> tensor<1x3xi32> {
10841084
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
1085-
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1086-
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1087-
// CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
1088-
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
1089-
// CHECK: return %[[VAL_3]] : tensor<1x3xi32>
1085+
// CHECK: %[[K:.*]] = "tosa.const"() <{value = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
1086+
// CHECK: return %[[K]] : tensor<1x3xi32>
10901087
%arg0 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
10911088
%arg1 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
10921089
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

0 commit comments

Comments
 (0)