@@ -548,15 +548,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
548
548
// Operator Folders.
549
549
// ===----------------------------------------------------------------------===//
550
550
551
- template <typename IntFolder, typename FloatFolder>
551
+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
552
552
DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
553
553
RankedTensorType returnTy) {
554
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
555
- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
556
- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
557
- if (lETy != rETy)
558
- return {};
554
+ if (!rhs || !lhs)
555
+ return {};
556
+
557
+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
558
+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
559
+ if (lETy != rETy)
560
+ return {};
561
+
562
+ if (!lETy.isIntOrFloat ())
563
+ return {};
559
564
565
+ if (rhs.isSplat () && lhs.isSplat ()) {
560
566
if (llvm::isa<IntegerType>(lETy)) {
561
567
APInt l = lhs.getSplatValue <APInt>();
562
568
APInt r = rhs.getSplatValue <APInt>();
@@ -572,9 +578,60 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
572
578
}
573
579
}
574
580
581
+ auto lhsCount = lhs.getNumElements ();
582
+ auto rhsCount = rhs.getNumElements ();
583
+ if (lhsCount != rhsCount) {
584
+ return {};
585
+ }
586
+
587
+ const int64_t MAX_ELEMENT_COUNT = 128 ;
588
+ if (lhsCount > MAX_ELEMENT_COUNT) {
589
+ // to prevent long compile time, skip if too many elements
590
+ return {};
591
+ }
592
+
593
+ if (llvm::isa<IntegerType>(lETy)) {
594
+ auto lvalues = lhs.getValues <APInt>();
595
+ auto rvalues = rhs.getValues <APInt>();
596
+ SmallVector<APInt> results;
597
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
598
+ auto result = IntFolder ()(l, r);
599
+ results.push_back (result);
600
+ }
601
+ return DenseElementsAttr::get (returnTy, results);
602
+ }
603
+
604
+ if (llvm::isa<FloatType>(lETy)) {
605
+ auto lvalues = lhs.getValues <APFloat>();
606
+ auto rvalues = rhs.getValues <APFloat>();
607
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
608
+ SmallVector<FloatResultAPType> results;
609
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
610
+ auto result = FloatFolder ()(l, r);
611
+ results.push_back (result);
612
+ }
613
+ return DenseElementsAttr::get (returnTy, results);
614
+ }
615
+
575
616
return {};
576
617
}
577
618
619
+ template <typename IntFolder, typename FloatFolder>
620
+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
621
+ DenseElementsAttr rhs,
622
+ RankedTensorType returnTy) {
623
+ // comparison FloatFolder() functions return APInt values
624
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
625
+ }
626
+
627
+ template <typename IntFolder, typename FloatFolder>
628
+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
629
+ DenseElementsAttr rhs,
630
+ RankedTensorType returnTy) {
631
+ // arithmetic FloatFolder() functions return APFloat values
632
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
633
+ }
634
+
578
635
static bool isSplatZero (Type elemType, DenseElementsAttr val) {
579
636
if (llvm::isa<FloatType>(elemType))
580
637
return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -621,8 +678,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
621
678
if (!lhsAttr || !rhsAttr)
622
679
return {};
623
680
624
- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
625
- resultTy);
681
+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
682
+ lhsAttr, rhsAttr, resultTy);
626
683
}
627
684
628
685
OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -679,32 +736,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
679
736
}
680
737
681
738
namespace {
739
+
740
+ // calculate lhs * rhs >> shift according to TOSA Spec
741
+ // return nullopt if result is not in range of int32_t when shift > 0
742
+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
743
+ unsigned bitwidth) {
744
+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
745
+
746
+ if (shift > 0 ) {
747
+ auto round = APInt (64 , 1 ) << (shift - 1 );
748
+ result += round;
749
+ result.ashrInPlace (shift);
750
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
751
+ if (!(result.getSExtValue () >= INT32_MIN &&
752
+ result.getSExtValue () <= INT32_MAX)) {
753
+ // REQUIRE failed
754
+ return std::nullopt;
755
+ }
756
+ }
757
+
758
+ return result.trunc (bitwidth);
759
+ }
760
+
682
761
DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
683
762
RankedTensorType ty, int32_t shift) {
684
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
685
- if (llvm::isa<IntegerType>(ty.getElementType ())) {
686
- APInt l = lhs.getSplatValue <APInt>();
687
- APInt r = rhs.getSplatValue <APInt>();
763
+ if (!lhs || !rhs)
764
+ return {};
765
+
766
+ // REQUIRE(0 <= shift && shift <= 63);
767
+ if (!(0 <= shift && shift <= 63 ))
768
+ return {};
688
769
689
- if (shift == 0 ) {
690
- return DenseElementsAttr::get (ty, l * r);
770
+ auto elementType = ty.getElementType ();
771
+ if (!elementType.isIntOrFloat ())
772
+ return {};
773
+
774
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
775
+ // REQUIRE(in_t == int32_t || shift == 0);
776
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
777
+ return {};
778
+
779
+ if (rhs.isSplat () && lhs.isSplat ()) {
780
+ if (llvm::isa<IntegerType>(elementType)) {
781
+ auto l = lhs.getSplatValue <APInt>();
782
+ auto r = rhs.getSplatValue <APInt>();
783
+
784
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
785
+ return DenseElementsAttr::get (ty, result.value ());
691
786
}
787
+ // mulInt failed
788
+ return {};
789
+ }
692
790
693
- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
694
- l = l. sext (bitwidth * 2 );
695
- r = r. sext (bitwidth * 2 );
791
+ if (llvm::isa<FloatType>(elementType)) {
792
+ auto l = lhs. getSplatValue <APFloat>( );
793
+ auto r = rhs. getSplatValue <APFloat>( );
696
794
auto result = l * r;
697
- result.lshrInPlace (shift);
698
- result = result.trunc (bitwidth);
699
795
return DenseElementsAttr::get (ty, result);
700
796
}
797
+ }
701
798
702
- if (llvm::isa<FloatType>(ty. getElementType () )) {
703
- APFloat l = lhs.getSplatValue <APFloat >();
704
- APFloat r = rhs.getSplatValue <APFloat >();
705
- APFloat result = l * r;
706
- return DenseElementsAttr::get (ty, result) ;
799
+ if (llvm::isa<IntegerType>(elementType )) {
800
+ auto lvalues = lhs.getValues <APInt >();
801
+ auto rvalues = rhs.getValues <APInt >();
802
+ if (lvalues. size () != rvalues. size ()) {
803
+ return {} ;
707
804
}
805
+ SmallVector<APInt> results;
806
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
807
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
808
+ results.push_back (result.value ());
809
+ continue ;
810
+ }
811
+ // mulInt failed
812
+ return {};
813
+ }
814
+ return DenseElementsAttr::get (ty, results);
815
+ }
816
+
817
+ if (llvm::isa<FloatType>(elementType)) {
818
+ auto lvalues = lhs.getValues <APFloat>();
819
+ auto rvalues = rhs.getValues <APFloat>();
820
+ if (lvalues.size () != rvalues.size ()) {
821
+ return {};
822
+ }
823
+ SmallVector<APFloat> results;
824
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
825
+ auto result = l * r;
826
+ results.push_back (result);
827
+ }
828
+ return DenseElementsAttr::get (ty, results);
708
829
}
709
830
710
831
return {};
@@ -779,8 +900,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
779
900
if (!lhsAttr || !rhsAttr)
780
901
return {};
781
902
782
- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
783
- resultTy);
903
+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
904
+ lhsAttr, rhsAttr, resultTy);
784
905
}
785
906
786
907
namespace {
@@ -821,7 +942,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
821
942
if (!lhsAttr || !rhsAttr)
822
943
return {};
823
944
824
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
945
+ return comparisonBinaryFolder<APIntFoldGreater,
946
+ ComparisonFold<std::greater<APFloat>>>(
825
947
lhsAttr, rhsAttr, resultTy);
826
948
}
827
949
@@ -835,8 +957,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
835
957
if (!lhsAttr || !rhsAttr)
836
958
return {};
837
959
838
- return binaryFolder <APIntFoldGreaterEqual,
839
- ComparisonFold<std::greater_equal<APFloat>>>(
960
+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
961
+ ComparisonFold<std::greater_equal<APFloat>>>(
840
962
lhsAttr, rhsAttr, resultTy);
841
963
}
842
964
@@ -860,9 +982,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
860
982
if (!lhsAttr || !rhsAttr)
861
983
return {};
862
984
863
- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
864
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
865
- resultTy);
985
+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
986
+ ComparisonFold<std::equal_to<APFloat>>>(
987
+ lhsAttr, rhsAttr, resultTy);
866
988
}
867
989
868
990
OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments