@@ -501,15 +501,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
501
501
// Operator Folders.
502
502
// ===----------------------------------------------------------------------===//
503
503
504
- template <typename IntFolder, typename FloatFolder>
504
+ template <typename IntFolder, typename FloatFolder, typename FloatResultAPType >
505
505
DenseElementsAttr binaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
506
506
RankedTensorType returnTy) {
507
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
508
- auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
509
- auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
510
- if (lETy != rETy)
511
- return {};
507
+ if (!rhs || !lhs)
508
+ return {};
509
+
510
+ auto lETy = llvm::cast<ShapedType>(lhs.getType ()).getElementType ();
511
+ auto rETy = llvm::cast<ShapedType>(rhs.getType ()).getElementType ();
512
+ if (lETy != rETy)
513
+ return {};
514
+
515
+ if (!lETy.isIntOrFloat ())
516
+ return {};
512
517
518
+ if (rhs.isSplat () && lhs.isSplat ()) {
513
519
if (llvm::isa<IntegerType>(lETy)) {
514
520
APInt l = lhs.getSplatValue <APInt>();
515
521
APInt r = rhs.getSplatValue <APInt>();
@@ -525,9 +531,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
525
531
}
526
532
}
527
533
534
+ auto lhsCount = lhs.getNumElements ();
535
+ auto rhsCount = rhs.getNumElements ();
536
+ if (lhsCount != rhsCount)
537
+ return {};
538
+
539
+ // to prevent long compile time, skip if too many elements
540
+ if (lhsCount > 128 )
541
+ return {};
542
+
543
+ if (llvm::isa<IntegerType>(lETy)) {
544
+ auto lvalues = lhs.getValues <APInt>();
545
+ auto rvalues = rhs.getValues <APInt>();
546
+ SmallVector<APInt> results;
547
+ IntFolder intFolder{};
548
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
549
+ auto result = intFolder (l, r);
550
+ results.push_back (result);
551
+ }
552
+ return DenseElementsAttr::get (returnTy, results);
553
+ }
554
+
555
+ if (llvm::isa<FloatType>(lETy)) {
556
+ auto lvalues = lhs.getValues <APFloat>();
557
+ auto rvalues = rhs.getValues <APFloat>();
558
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
559
+ SmallVector<FloatResultAPType> results;
560
+ FloatFolder floatFolder{};
561
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
562
+ auto result = floatFolder (l, r);
563
+ results.push_back (result);
564
+ }
565
+ return DenseElementsAttr::get (returnTy, results);
566
+ }
567
+
528
568
return {};
529
569
}
530
570
571
+ template <typename IntFolder, typename FloatFolder>
572
+ DenseElementsAttr comparisonBinaryFolder (DenseElementsAttr lhs,
573
+ DenseElementsAttr rhs,
574
+ RankedTensorType returnTy) {
575
+ // comparison FloatFolder() functions return APInt values
576
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
577
+ }
578
+
579
+ template <typename IntFolder, typename FloatFolder>
580
+ DenseElementsAttr arithmeticBinaryFolder (DenseElementsAttr lhs,
581
+ DenseElementsAttr rhs,
582
+ RankedTensorType returnTy) {
583
+ // arithmetic FloatFolder() functions return APFloat values
584
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
585
+ }
586
+
531
587
static bool isSplatZero (Type elemType, DenseElementsAttr val) {
532
588
if (llvm::isa<FloatType>(elemType))
533
589
return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
@@ -574,8 +630,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
574
630
if (!lhsAttr || !rhsAttr)
575
631
return {};
576
632
577
- return binaryFolder <std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
578
- resultTy);
633
+ return arithmeticBinaryFolder <std::plus<APInt>, std::plus<APFloat>>(
634
+ lhsAttr, rhsAttr, resultTy);
579
635
}
580
636
581
637
OpFoldResult ArgMaxOp::fold (FoldAdaptor adaptor) {
@@ -632,32 +688,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
632
688
}
633
689
634
690
namespace {
691
+
692
+ // calculate lhs * rhs >> shift according to TOSA Spec
693
+ // return nullopt if result is not in range of int32_t when shift > 0
694
+ std::optional<APInt> mulInt (APInt lhs, APInt rhs, int32_t shift,
695
+ unsigned bitwidth) {
696
+ APInt result = lhs.sext (64 ) * rhs.sext (64 );
697
+
698
+ if (shift > 0 ) {
699
+ auto round = APInt (64 , 1 ) << (shift - 1 );
700
+ result += round;
701
+ result.ashrInPlace (shift);
702
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
703
+ if (!(result.getSExtValue () >= INT32_MIN &&
704
+ result.getSExtValue () <= INT32_MAX)) {
705
+ // REQUIRE failed
706
+ return std::nullopt;
707
+ }
708
+ }
709
+
710
+ return result.trunc (bitwidth);
711
+ }
712
+
635
713
DenseElementsAttr mulBinaryFolder (DenseElementsAttr lhs, DenseElementsAttr rhs,
636
714
RankedTensorType ty, int32_t shift) {
637
- if (rhs && lhs && rhs.isSplat () && lhs.isSplat ()) {
638
- if (llvm::isa<IntegerType>(ty.getElementType ())) {
639
- APInt l = lhs.getSplatValue <APInt>();
640
- APInt r = rhs.getSplatValue <APInt>();
715
+ if (!lhs || !rhs)
716
+ return {};
717
+
718
+ // REQUIRE(0 <= shift && shift <= 63);
719
+ if (!(0 <= shift && shift <= 63 ))
720
+ return {};
721
+
722
+ auto elementType = ty.getElementType ();
723
+ if (!elementType.isIntOrFloat ())
724
+ return {};
641
725
642
- if (shift == 0 ) {
643
- return DenseElementsAttr::get (ty, l * r);
726
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth ();
727
+ // REQUIRE(in_t == int32_t || shift == 0);
728
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32 ) || shift == 0 ))
729
+ return {};
730
+
731
+ if (rhs.isSplat () && lhs.isSplat ()) {
732
+ if (llvm::isa<IntegerType>(elementType)) {
733
+ auto l = lhs.getSplatValue <APInt>();
734
+ auto r = rhs.getSplatValue <APInt>();
735
+
736
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
737
+ return DenseElementsAttr::get (ty, result.value ());
644
738
}
739
+ // mulInt failed
740
+ return {};
741
+ }
645
742
646
- auto bitwidth = ty. getElementType (). getIntOrFloatBitWidth ();
647
- l = l. sext (bitwidth * 2 );
648
- r = r. sext (bitwidth * 2 );
743
+ if (llvm::isa<FloatType>(elementType)) {
744
+ auto l = lhs. getSplatValue <APFloat>( );
745
+ auto r = rhs. getSplatValue <APFloat>( );
649
746
auto result = l * r;
650
- result.lshrInPlace (shift);
651
- result = result.trunc (bitwidth);
652
747
return DenseElementsAttr::get (ty, result);
653
748
}
749
+ }
654
750
655
- if (llvm::isa<FloatType>(ty.getElementType ())) {
656
- APFloat l = lhs.getSplatValue <APFloat>();
657
- APFloat r = rhs.getSplatValue <APFloat>();
658
- APFloat result = l * r;
659
- return DenseElementsAttr::get (ty, result);
751
+ if (llvm::isa<IntegerType>(elementType)) {
752
+ auto lvalues = lhs.getValues <APInt>();
753
+ auto rvalues = rhs.getValues <APInt>();
754
+ if (lvalues.size () != rvalues.size ()) {
755
+ return {};
756
+ }
757
+ SmallVector<APInt> results;
758
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
759
+ if (auto result = mulInt (l, r, shift, bitwidth)) {
760
+ results.push_back (result.value ());
761
+ continue ;
762
+ }
763
+ // mulInt failed
764
+ return {};
765
+ }
766
+ return DenseElementsAttr::get (ty, results);
767
+ }
768
+
769
+ if (llvm::isa<FloatType>(elementType)) {
770
+ auto lvalues = lhs.getValues <APFloat>();
771
+ auto rvalues = rhs.getValues <APFloat>();
772
+ if (lvalues.size () != rvalues.size ()) {
773
+ return {};
660
774
}
775
+ SmallVector<APFloat> results;
776
+ for (const auto &[l, r] : llvm::zip (lvalues, rvalues)) {
777
+ auto result = l * r;
778
+ results.push_back (result);
779
+ }
780
+ return DenseElementsAttr::get (ty, results);
661
781
}
662
782
663
783
return {};
@@ -732,8 +852,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
732
852
if (!lhsAttr || !rhsAttr)
733
853
return {};
734
854
735
- return binaryFolder <std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
736
- resultTy);
855
+ return arithmeticBinaryFolder <std::minus<APInt>, std::minus<APFloat>>(
856
+ lhsAttr, rhsAttr, resultTy);
737
857
}
738
858
739
859
namespace {
@@ -774,7 +894,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
774
894
if (!lhsAttr || !rhsAttr)
775
895
return {};
776
896
777
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
897
+ return comparisonBinaryFolder<APIntFoldGreater,
898
+ ComparisonFold<std::greater<APFloat>>>(
778
899
lhsAttr, rhsAttr, resultTy);
779
900
}
780
901
@@ -788,8 +909,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
788
909
if (!lhsAttr || !rhsAttr)
789
910
return {};
790
911
791
- return binaryFolder <APIntFoldGreaterEqual,
792
- ComparisonFold<std::greater_equal<APFloat>>>(
912
+ return comparisonBinaryFolder <APIntFoldGreaterEqual,
913
+ ComparisonFold<std::greater_equal<APFloat>>>(
793
914
lhsAttr, rhsAttr, resultTy);
794
915
}
795
916
@@ -813,9 +934,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
813
934
if (!lhsAttr || !rhsAttr)
814
935
return {};
815
936
816
- return binaryFolder <ComparisonFold<std::equal_to<APInt>>,
817
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
818
- resultTy);
937
+ return comparisonBinaryFolder <ComparisonFold<std::equal_to<APInt>>,
938
+ ComparisonFold<std::equal_to<APFloat>>>(
939
+ lhsAttr, rhsAttr, resultTy);
819
940
}
820
941
821
942
OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
0 commit comments