@@ -548,15 +548,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
548
548
// Operator Folders.
549
549
// ===----------------------------------------------------------------------===//
550
550
551
- template <typename IntFolder, typename FloatFolder>
551
+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
552
552
DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
553
553
RankedTensorType returnTy) {
554
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
555
- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
556
- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
557
- if (lETy != rETy)
558
- return {};
554
+ if (!rhs || !lhs)
555
+ return {};
556
+
557
+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
558
+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
559
+ if (lETy != rETy)
560
+ return {};
561
+
562
+ if (!lETy.isIntOrFloat ())
563
+ return {};
559
564
565
+ if (rhs.isSplat () && lhs.isSplat ()) {
560
566
if (llvm::isa<IntegerType>(lETy)) {
561
567
APInt l = lhs.getSplatValue <APInt>();
562
568
APInt r = rhs.getSplatValue <APInt>();
@@ -572,9 +578,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
572
578
}
573
579
}
574
580
581
+ auto lhsCount = lhs.getNumElements ();
582
+ auto rhsCount = rhs.getNumElements ();
583
+ if (lhsCount != rhsCount)
584
+ return {};
585
+
586
+ // to prevent long compile time, skip if too many elements
587
+ if (lhsCount > 128 )
588
+ return {};
589
+
590
+ if (llvm::isa<IntegerType>(lETy)) {
591
+ auto lvalues = lhs.getValues <APInt>();
592
+ auto rvalues = rhs.getValues <APInt>();
593
+ SmallVector<APInt> results;
594
+ IntFolder intFolder{};
595
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
596
+ auto result = intFolder (l, r);
597
+ results.push_back (result);
598
+ }
599
+ return DenseElementsAttr::get (returnTy, results);
600
+ }
601
+
602
+ if (llvm::isa<FloatType>(lETy)) {
603
+ auto lvalues = lhs.getValues <APFloat>();
604
+ auto rvalues = rhs.getValues <APFloat>();
605
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
606
+ SmallVector<FloatResultAPType> results;
607
+ FloatFolder floatFolder{};
608
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
609
+ auto result = floatFolder (l, r);
610
+ results.push_back (result);
611
+ }
612
+ return DenseElementsAttr::get (returnTy, results);
613
+ }
614
+
575
615
return {};
576
616
}
577
617
618
+ template <typename IntFolder, typename FloatFolder>
619
+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
620
+ DenseElementsAttr rhs,
621
+ RankedTensorType returnTy) {
622
+ // comparison FloatFolder() functions return APInt values
623
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
624
+ }
625
+
626
+ template <typename IntFolder, typename FloatFolder>
627
+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
628
+ DenseElementsAttr rhs,
629
+ RankedTensorType returnTy) {
630
+ // arithmetic FloatFolder() functions return APFloat values
631
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
632
+ }
633
+
578
634
static bool isSplatZero (Type elemType, DenseElementsAttr val) {
579
635
if (llvm::isa<FloatType>(elemType))
580
636
return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -621,8 +677,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
621
677
if (!lhsAttr || !rhsAttr)
622
678
return {};
623
679
624
- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
625
- resultTy);
680
+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
681
+ lhsAttr, rhsAttr, resultTy);
626
682
}
627
683
628
684
OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -679,32 +735,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
679
735
}
680
736
681
737
namespace {
738
+
739
+ // calculate lhs * rhs >> shift according to TOSA Spec
740
+ // return nullopt if result is not in range of int32_t when shift > 0
741
+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
742
+ unsigned bitwidth) {
743
+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
744
+
745
+ if (shift > 0 ) {
746
+ auto round = APInt (64 , 1 ) << (shift - 1 );
747
+ result += round;
748
+ result.ashrInPlace (shift);
749
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
750
+ if (!(result.getSExtValue () >= INT32_MIN &&
751
+ result.getSExtValue () <= INT32_MAX)) {
752
+ // REQUIRE failed
753
+ return std::nullopt;
754
+ }
755
+ }
756
+
757
+ return result.trunc (bitwidth);
758
+ }
759
+
682
760
DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
683
761
RankedTensorType ty, int32_t shift) {
684
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
685
- if (llvm::isa<IntegerType>(ty.getElementType ())) {
686
- APInt l = lhs.getSplatValue <APInt>();
687
- APInt r = rhs.getSplatValue <APInt>();
762
+ if (!lhs || !rhs)
763
+ return {};
764
+
765
+ // REQUIRE(0 <= shift && shift <= 63);
766
+ if (!(0 <= shift && shift <= 63 ))
767
+ return {};
768
+
769
+ auto elementType = ty.getElementType ();
770
+ if (!elementType.isIntOrFloat ())
771
+ return {};
688
772
689
- if (shift == 0 ) {
690
- return DenseElementsAttr::get (ty, l * r);
773
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
774
+ // REQUIRE(in_t == int32_t || shift == 0);
775
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
776
+ return {};
777
+
778
+ if (rhs.isSplat () && lhs.isSplat ()) {
779
+ if (llvm::isa<IntegerType>(elementType)) {
780
+ auto l = lhs.getSplatValue <APInt>();
781
+ auto r = rhs.getSplatValue <APInt>();
782
+
783
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
784
+ return DenseElementsAttr::get (ty, result.value ());
691
785
}
786
+ // mulInt failed
787
+ return {};
788
+ }
692
789
693
- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
694
- l = l. sext (bitwidth * 2 );
695
- r = r. sext (bitwidth * 2 );
790
+ if (llvm::isa<FloatType>(elementType)) {
791
+ auto l = lhs. getSplatValue <APFloat>( );
792
+ auto r = rhs. getSplatValue <APFloat>( );
696
793
auto result = l * r;
697
- result.lshrInPlace (shift);
698
- result = result.trunc (bitwidth);
699
794
return DenseElementsAttr::get (ty, result);
700
795
}
796
+ }
701
797
702
- if (llvm::isa<FloatType>(ty.getElementType ())) {
703
- APFloat l = lhs.getSplatValue <APFloat>();
704
- APFloat r = rhs.getSplatValue <APFloat>();
705
- APFloat result = l * r;
706
- return DenseElementsAttr::get (ty, result);
798
+ if (llvm::isa<IntegerType>(elementType)) {
799
+ auto lvalues = lhs.getValues <APInt>();
800
+ auto rvalues = rhs.getValues <APInt>();
801
+ if (lvalues.size () != rvalues.size ()) {
802
+ return {};
803
+ }
804
+ SmallVector<APInt> results;
805
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
806
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
807
+ results.push_back (result.value ());
808
+ continue ;
809
+ }
810
+ // mulInt failed
811
+ return {};
812
+ }
813
+ return DenseElementsAttr::get (ty, results);
814
+ }
815
+
816
+ if (llvm::isa<FloatType>(elementType)) {
817
+ auto lvalues = lhs.getValues <APFloat>();
818
+ auto rvalues = rhs.getValues <APFloat>();
819
+ if (lvalues.size () != rvalues.size ()) {
820
+ return {};
707
821
}
822
+ SmallVector<APFloat> results;
823
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
824
+ auto result = l * r;
825
+ results.push_back (result);
826
+ }
827
+ return DenseElementsAttr::get (ty, results);
708
828
}
709
829
710
830
return {};
@@ -779,8 +899,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
779
899
if (!lhsAttr || !rhsAttr)
780
900
return {};
781
901
782
- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
783
- resultTy);
902
+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
903
+ lhsAttr, rhsAttr, resultTy);
784
904
}
785
905
786
906
namespace {
@@ -821,7 +941,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
821
941
if (!lhsAttr || !rhsAttr)
822
942
return {};
823
943
824
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
944
+ return comparisonBinaryFolder<APIntFoldGreater,
945
+ ComparisonFold<std::greater<APFloat>>>(
825
946
lhsAttr, rhsAttr, resultTy);
826
947
}
827
948
@@ -835,8 +956,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
835
956
if (!lhsAttr || !rhsAttr)
836
957
return {};
837
958
838
- return binaryFolder <APIntFoldGreaterEqual,
839
- ComparisonFold<std::greater_equal<APFloat>>>(
959
+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
960
+ ComparisonFold<std::greater_equal<APFloat>>>(
840
961
lhsAttr, rhsAttr, resultTy);
841
962
}
842
963
@@ -860,9 +981,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
860
981
if (!lhsAttr || !rhsAttr)
861
982
return {};
862
983
863
- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
864
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
865
- resultTy);
984
+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
985
+ ComparisonFold<std::equal_to<APFloat>>>(
986
+ lhsAttr, rhsAttr, resultTy);
866
987
}
867
988
868
989
OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments