Skip to content

Commit 1c90de5

Browse files
authored
[analyzer] Allow recursive functions to be trivial. (#91876)
1 parent 0338c55 commit 1c90de5

File tree

2 files changed

+73
-29
lines changed

2 files changed

+73
-29
lines changed

clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,43 @@ class TrivialFunctionAnalysisVisitor
271271

272272
TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}
273273

274+
bool IsFunctionTrivial(const Decl *D) {
275+
auto CacheIt = Cache.find(D);
276+
if (CacheIt != Cache.end())
277+
return CacheIt->second;
278+
279+
// Treat a recursive function call to be trivial until proven otherwise.
280+
auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(D, true));
281+
if (!IsNew)
282+
return RecursiveIt->second;
283+
284+
bool Result = [&]() {
285+
if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
286+
for (auto *CtorInit : CtorDecl->inits()) {
287+
if (!Visit(CtorInit->getInit()))
288+
return false;
289+
}
290+
}
291+
const Stmt *Body = D->getBody();
292+
if (!Body)
293+
return false;
294+
return Visit(Body);
295+
}();
296+
297+
if (!Result) {
298+
// D and its mutually recursive callers are all non-trivial.
299+
for (auto &It : RecursiveFn)
300+
It.second = false;
301+
}
302+
RecursiveIt = RecursiveFn.find(D);
303+
assert(RecursiveIt != RecursiveFn.end());
304+
Result = RecursiveIt->second;
305+
RecursiveFn.erase(RecursiveIt);
306+
Cache[D] = Result;
307+
308+
return Result;
309+
}
310+
274311
bool VisitStmt(const Stmt *S) {
275312
// All statements are non-trivial unless overriden later.
276313
// Don't even recurse into children by default.
@@ -368,7 +405,7 @@ class TrivialFunctionAnalysisVisitor
368405
Name == "bitwise_cast" || Name.find("__builtin") == 0)
369406
return true;
370407

371-
return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
408+
return IsFunctionTrivial(Callee);
372409
}
373410

374411
bool
@@ -403,7 +440,7 @@ class TrivialFunctionAnalysisVisitor
403440
return true;
404441

405442
// Recursively descend into the callee to confirm that it's trivial as well.
406-
return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
443+
return IsFunctionTrivial(Callee);
407444
}
408445

409446
bool VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) {
@@ -413,7 +450,7 @@ class TrivialFunctionAnalysisVisitor
413450
if (!Callee)
414451
return false;
415452
// Recursively descend into the callee to confirm that it's trivial as well.
416-
return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
453+
return IsFunctionTrivial(Callee);
417454
}
418455

419456
bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
@@ -439,7 +476,7 @@ class TrivialFunctionAnalysisVisitor
439476
}
440477

441478
// Recursively descend into the callee to confirm that it's trivial.
442-
return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache);
479+
return IsFunctionTrivial(CE->getConstructor());
443480
}
444481

445482
bool VisitCXXNewExpr(const CXXNewExpr *NE) { return VisitChildren(NE); }
@@ -513,36 +550,13 @@ class TrivialFunctionAnalysisVisitor
513550

514551
private:
515552
CacheTy &Cache;
553+
CacheTy RecursiveFn;
516554
};
517555

518556
bool TrivialFunctionAnalysis::isTrivialImpl(
519557
const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
520-
// If the function isn't in the cache, conservatively assume that
521-
// it's not trivial until analysis completes. This makes every recursive
522-
// function non-trivial. This also guarantees that each function
523-
// will be scanned at most once.
524-
auto [It, IsNew] = Cache.insert(std::make_pair(D, false));
525-
if (!IsNew)
526-
return It->second;
527-
528558
TrivialFunctionAnalysisVisitor V(Cache);
529-
530-
if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
531-
for (auto *CtorInit : CtorDecl->inits()) {
532-
if (!V.Visit(CtorInit->getInit()))
533-
return false;
534-
}
535-
}
536-
537-
const Stmt *Body = D->getBody();
538-
if (!Body)
539-
return false;
540-
541-
bool Result = V.Visit(Body);
542-
if (Result)
543-
Cache[D] = true;
544-
545-
return Result;
559+
return V.IsFunctionTrivial(D);
546560
}
547561

548562
bool TrivialFunctionAnalysis::isTrivialImpl(

clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,18 @@ class RefCounted {
231231
void method();
232232
void someFunction();
233233
int otherFunction();
234+
unsigned recursiveTrivialFunction(int n) { return !n ? 1 : recursiveTrivialFunction(n - 1); }
235+
unsigned recursiveComplexFunction(int n) { return !n ? otherFunction() : recursiveComplexFunction(n - 1); }
236+
unsigned mutuallyRecursiveFunction1(int n) { return n < 0 ? 1 : (n % 2 ? mutuallyRecursiveFunction2(n - 2) : mutuallyRecursiveFunction1(n - 1)); }
237+
unsigned mutuallyRecursiveFunction2(int n) { return n < 0 ? 1 : (n % 3 ? mutuallyRecursiveFunction2(n - 3) : mutuallyRecursiveFunction1(n - 2)); }
238+
unsigned mutuallyRecursiveFunction3(int n) { return n < 0 ? 1 : (n % 5 ? mutuallyRecursiveFunction3(n - 5) : mutuallyRecursiveFunction4(n - 3)); }
239+
unsigned mutuallyRecursiveFunction4(int n) { return n < 0 ? 1 : (n % 7 ? otherFunction() : mutuallyRecursiveFunction3(n - 3)); }
240+
unsigned recursiveFunction5(unsigned n) { return n > 100 ? 2 : (n % 2 ? recursiveFunction5(n + 1) : recursiveFunction6(n + 2)); }
241+
unsigned recursiveFunction6(unsigned n) { return n > 100 ? 3 : (n % 2 ? recursiveFunction6(n % 7) : recursiveFunction7(n % 5)); }
242+
unsigned recursiveFunction7(unsigned n) { return n > 100 ? 5 : recursiveFunction7(n * 5); }
243+
244+
void mutuallyRecursive8() { mutuallyRecursive9(); someFunction(); }
245+
void mutuallyRecursive9() { mutuallyRecursive8(); }
234246

235247
int trivial1() { return 123; }
236248
float trivial2() { return 0.3; }
@@ -498,6 +510,24 @@ class UnrelatedClass {
498510
RefCounted::singleton().trivial18(); // no-warning
499511
RefCounted::singleton().someFunction(); // no-warning
500512

513+
getFieldTrivial().recursiveTrivialFunction(7); // no-warning
514+
getFieldTrivial().recursiveComplexFunction(9);
515+
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
516+
getFieldTrivial().mutuallyRecursiveFunction1(11); // no-warning
517+
getFieldTrivial().mutuallyRecursiveFunction2(13); // no-warning
518+
getFieldTrivial().mutuallyRecursiveFunction3(17);
519+
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
520+
getFieldTrivial().mutuallyRecursiveFunction4(19);
521+
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
522+
getFieldTrivial().recursiveFunction5(23); // no-warning
523+
getFieldTrivial().recursiveFunction6(29); // no-warning
524+
getFieldTrivial().recursiveFunction7(31); // no-warning
525+
526+
getFieldTrivial().mutuallyRecursive8();
527+
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
528+
getFieldTrivial().mutuallyRecursive9();
529+
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
530+
501531
getFieldTrivial().someFunction();
502532
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
503533
getFieldTrivial().nonTrivial1();

0 commit comments

Comments
 (0)