@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
563
563
// Operator Folders.
564
564
// ===----------------------------------------------------------------------===//
565
565
566
- template <typename IntFolder, typename FloatFolder>
566
+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
567
567
DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
568
568
RankedTensorType returnTy) {
569
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
570
- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
571
- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
572
- if (lETy != rETy)
573
- return {};
569
+ if (!rhs || !lhs)
570
+ return {};
571
+
572
+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
573
+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
574
+ if (lETy != rETy)
575
+ return {};
576
+
577
+ if (!lETy.isIntOrFloat ())
578
+ return {};
574
579
580
+ if (rhs.isSplat () && lhs.isSplat ()) {
575
581
if (llvm::isa<IntegerType>(lETy)) {
576
582
APInt l = lhs.getSplatValue <APInt>();
577
583
APInt r = rhs.getSplatValue <APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
587
593
}
588
594
}
589
595
596
+ if (llvm::isa<IntegerType>(lETy)) {
597
+ auto lvalues = lhs.getValues <APInt>();
598
+ auto rvalues = rhs.getValues <APInt>();
599
+ if (lvalues.size () != rvalues.size ()) {
600
+ return {};
601
+ }
602
+ SmallVector<APInt> results;
603
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
604
+ auto result = IntFolder ()(l, r);
605
+ results.push_back (result);
606
+ }
607
+ return DenseElementsAttr::get (returnTy, results);
608
+ }
609
+
610
+ if (llvm::isa<FloatType>(lETy)) {
611
+ auto lvalues = lhs.getValues <APFloat>();
612
+ auto rvalues = rhs.getValues <APFloat>();
613
+ if (lvalues.size () != rvalues.size ()) {
614
+ return {};
615
+ }
616
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
617
+ SmallVector<FloatResultAPType> results;
618
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
619
+ auto result = FloatFolder ()(l, r);
620
+ results.push_back (result);
621
+ }
622
+ return DenseElementsAttr::get (returnTy, results);
623
+ }
624
+
590
625
return {};
591
626
}
592
627
628
+ template <typename IntFolder, typename FloatFolder>
629
+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
630
+ DenseElementsAttr rhs,
631
+ RankedTensorType returnTy) {
632
+ // comparison FloatFolder() functions return APInt values
633
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
634
+ }
635
+
636
+ template <typename IntFolder, typename FloatFolder>
637
+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
638
+ DenseElementsAttr rhs,
639
+ RankedTensorType returnTy) {
640
+ // arithmetic FloatFolder() functions return APFloat values
641
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
642
+ }
643
+
593
644
static bool isSplatZero (Type elemType, DenseElementsAttr val) {
594
645
if (llvm::isa<FloatType>(elemType))
595
646
return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
636
687
if (!lhsAttr || !rhsAttr)
637
688
return {};
638
689
639
- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
640
- resultTy);
690
+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
691
+ lhsAttr, rhsAttr, resultTy);
641
692
}
642
693
643
694
OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
693
744
}
694
745
695
746
namespace {
747
+
748
+ // calculate lhs * rhs >> shift according to TOSA Spec
749
+ // return nullopt if result is not in range of int32_t when shift > 0
750
+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
751
+ unsigned bitwidth) {
752
+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
753
+
754
+ if (shift > 0 ) {
755
+ auto round = APInt (64 , 1 ) << (shift - 1 );
756
+ result += round;
757
+ result.ashrInPlace (shift);
758
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
759
+ if (!(result.getSExtValue () >= INT32_MIN &&
760
+ result.getSExtValue () <= INT32_MAX)) {
761
+ // REQUIRE failed
762
+ return std::nullopt;
763
+ }
764
+ }
765
+
766
+ return result.trunc (bitwidth);
767
+ }
768
+
696
769
DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
697
770
RankedTensorType ty, int32_t shift) {
698
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
699
- if (llvm::isa<IntegerType>(ty.getElementType ())) {
700
- APInt l = lhs.getSplatValue <APInt>();
701
- APInt r = rhs.getSplatValue <APInt>();
771
+ if (!lhs || !rhs)
772
+ return {};
773
+
774
+ // REQUIRE(0 <= shift && shift <= 63);
775
+ if (!(0 <= shift && shift <= 63 ))
776
+ return {};
702
777
703
- if (shift == 0 ) {
704
- return DenseElementsAttr::get (ty, l * r);
778
+ auto elementType = ty.getElementType ();
779
+ if (!elementType.isIntOrFloat ())
780
+ return {};
781
+
782
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
783
+ // REQUIRE(in_t == int32_t || shift == 0);
784
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
785
+ return {};
786
+
787
+ if (rhs.isSplat () && lhs.isSplat ()) {
788
+ if (llvm::isa<IntegerType>(elementType)) {
789
+ auto l = lhs.getSplatValue <APInt>();
790
+ auto r = rhs.getSplatValue <APInt>();
791
+
792
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
793
+ return DenseElementsAttr::get (ty, result.value ());
705
794
}
795
+ // mulInt failed
796
+ return {};
797
+ }
706
798
707
- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
708
- l = l. sext (bitwidth * 2 );
709
- r = r. sext (bitwidth * 2 );
799
+ if (llvm::isa<FloatType>(elementType)) {
800
+ auto l = lhs. getSplatValue <APFloat>( );
801
+ auto r = rhs. getSplatValue <APFloat>( );
710
802
auto result = l * r;
711
- result.lshrInPlace (shift);
712
- result = result.trunc (bitwidth);
713
803
return DenseElementsAttr::get (ty, result);
714
804
}
805
+ }
806
+
807
+ if (llvm::isa<IntegerType>(elementType)) {
808
+ auto lvalues = lhs.getValues <APInt>();
809
+ auto rvalues = rhs.getValues <APInt>();
810
+ if (lvalues.size () != rvalues.size ()) {
811
+ return {};
812
+ }
813
+ SmallVector<APInt> results;
814
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
815
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
816
+ results.push_back (result.value ());
817
+ continue ;
818
+ }
819
+ // mulInt failed
820
+ return {};
821
+ }
822
+ return DenseElementsAttr::get (ty, results);
823
+ }
715
824
716
- if (llvm::isa<FloatType>(ty. getElementType () )) {
717
- APFloat l = lhs.getSplatValue <APFloat>();
718
- APFloat r = rhs.getSplatValue <APFloat>();
719
- APFloat result = l * r;
720
- return DenseElementsAttr::get (ty, result) ;
825
+ if (llvm::isa<FloatType>(elementType )) {
826
+ auto lvalues = lhs.getValues <APFloat>();
827
+ auto rvalues = rhs.getValues <APFloat>();
828
+ if (lvalues. size () != rvalues. size ()) {
829
+ return {} ;
721
830
}
831
+ SmallVector<APFloat> results;
832
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
833
+ auto result = l * r;
834
+ results.push_back (result);
835
+ }
836
+ return DenseElementsAttr::get (ty, results);
722
837
}
723
838
724
839
return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
793
908
if (!lhsAttr || !rhsAttr)
794
909
return {};
795
910
796
- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
797
- resultTy);
911
+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
912
+ lhsAttr, rhsAttr, resultTy);
798
913
}
799
914
800
915
namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
835
950
if (!lhsAttr || !rhsAttr)
836
951
return {};
837
952
838
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
953
+ return comparisonBinaryFolder<APIntFoldGreater,
954
+ ComparisonFold<std::greater<APFloat>>>(
839
955
lhsAttr, rhsAttr, resultTy);
840
956
}
841
957
@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
849
965
if (!lhsAttr || !rhsAttr)
850
966
return {};
851
967
852
- return binaryFolder <APIntFoldGreaterEqual,
853
- ComparisonFold<std::greater_equal<APFloat>>>(
968
+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
969
+ ComparisonFold<std::greater_equal<APFloat>>>(
854
970
lhsAttr, rhsAttr, resultTy);
855
971
}
856
972
@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
874
990
if (!lhsAttr || !rhsAttr)
875
991
return {};
876
992
877
- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
878
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
879
- resultTy);
993
+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
994
+ ComparisonFold<std::equal_to<APFloat>>>(
995
+ lhsAttr, rhsAttr, resultTy);
880
996
}
881
997
882
998
OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments