@@ -91,6 +91,10 @@ using namespace llvm::PatternMatch;
91
91
92
92
STATISTIC (NumRemoved, " Number of unreachable basic blocks removed" );
93
93
94
+ // Max recursion depth for collectBitParts used when detecting bswap and
95
+ // bitreverse idioms
96
+ static const unsigned BitPartRecursionMaxDepth = 64 ;
97
+
94
98
// ===----------------------------------------------------------------------===//
95
99
// Local constant propagation.
96
100
//
@@ -2619,21 +2623,27 @@ struct BitPart {
2619
2623
// / does not invalidate internal references (std::map instead of DenseMap).
2620
2624
static const Optional<BitPart> &
2621
2625
collectBitParts (Value *V, bool MatchBSwaps, bool MatchBitReversals,
2622
- std::map<Value *, Optional<BitPart>> &BPS) {
2626
+ std::map<Value *, Optional<BitPart>> &BPS, int Depth ) {
2623
2627
auto I = BPS.find (V);
2624
2628
if (I != BPS.end ())
2625
2629
return I->second ;
2626
2630
2627
2631
auto &Result = BPS[V] = None;
2628
2632
auto BitWidth = cast<IntegerType>(V->getType ())->getBitWidth ();
2629
2633
2634
+ // Prevent stack overflow by limiting the recursion depth
2635
+ if (Depth == BitPartRecursionMaxDepth) {
2636
+ LLVM_DEBUG (dbgs () << " collectBitParts max recursion depth reached.\n " );
2637
+ return Result;
2638
+ }
2639
+
2630
2640
if (Instruction *I = dyn_cast<Instruction>(V)) {
2631
2641
// If this is an or instruction, it may be an inner node of the bswap.
2632
2642
if (I->getOpcode () == Instruction::Or) {
2633
2643
auto &A = collectBitParts (I->getOperand (0 ), MatchBSwaps,
2634
- MatchBitReversals, BPS);
2644
+ MatchBitReversals, BPS, Depth + 1 );
2635
2645
auto &B = collectBitParts (I->getOperand (1 ), MatchBSwaps,
2636
- MatchBitReversals, BPS);
2646
+ MatchBitReversals, BPS, Depth + 1 );
2637
2647
if (!A || !B)
2638
2648
return Result;
2639
2649
@@ -2666,7 +2676,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
2666
2676
return Result;
2667
2677
2668
2678
auto &Res = collectBitParts (I->getOperand (0 ), MatchBSwaps,
2669
- MatchBitReversals, BPS);
2679
+ MatchBitReversals, BPS, Depth + 1 );
2670
2680
if (!Res)
2671
2681
return Result;
2672
2682
Result = Res;
@@ -2698,7 +2708,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
2698
2708
return Result;
2699
2709
2700
2710
auto &Res = collectBitParts (I->getOperand (0 ), MatchBSwaps,
2701
- MatchBitReversals, BPS);
2711
+ MatchBitReversals, BPS, Depth + 1 );
2702
2712
if (!Res)
2703
2713
return Result;
2704
2714
Result = Res;
@@ -2713,7 +2723,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
2713
2723
// If this is a zext instruction zero extend the result.
2714
2724
if (I->getOpcode () == Instruction::ZExt) {
2715
2725
auto &Res = collectBitParts (I->getOperand (0 ), MatchBSwaps,
2716
- MatchBitReversals, BPS);
2726
+ MatchBitReversals, BPS, Depth + 1 );
2717
2727
if (!Res)
2718
2728
return Result;
2719
2729
@@ -2775,7 +2785,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom(
2775
2785
2776
2786
// Try to find all the pieces corresponding to the bswap.
2777
2787
std::map<Value *, Optional<BitPart>> BPS;
2778
- auto Res = collectBitParts (I, MatchBSwaps, MatchBitReversals, BPS);
2788
+ auto Res = collectBitParts (I, MatchBSwaps, MatchBitReversals, BPS, 0 );
2779
2789
if (!Res)
2780
2790
return false ;
2781
2791
auto &BitProvenance = Res->Provenance ;
0 commit comments