Skip to content

[analyzer] Allow recursive functions to be trivial. #91876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,43 @@ class TrivialFunctionAnalysisVisitor

TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}

bool IsFunctionTrivial(const Decl *D) {
auto CacheIt = Cache.find(D);
if (CacheIt != Cache.end())
return CacheIt->second;

// Treat a recursive function call to be trivial until proven otherwise.
auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(D, true));
if (!IsNew)
return RecursiveIt->second;

bool Result = [&]() {
if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
for (auto *CtorInit : CtorDecl->inits()) {
if (!Visit(CtorInit->getInit()))
return false;
}
}
const Stmt *Body = D->getBody();
if (!Body)
return false;
return Visit(Body);
}();

if (!Result) {
// D and its mutually recursive callers are all non-trivial.
for (auto &It : RecursiveFn)
It.second = false;
}
RecursiveIt = RecursiveFn.find(D);
assert(RecursiveIt != RecursiveFn.end());
Result = RecursiveIt->second;
RecursiveFn.erase(RecursiveIt);
Cache[D] = Result;

return Result;
}

bool VisitStmt(const Stmt *S) {
// All statements are non-trivial unless overriden later.
// Don't even recurse into children by default.
Expand Down Expand Up @@ -368,7 +405,7 @@ class TrivialFunctionAnalysisVisitor
Name == "bitwise_cast" || Name.find("__builtin") == 0)
return true;

return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
return IsFunctionTrivial(Callee);
}

bool
Expand Down Expand Up @@ -403,7 +440,7 @@ class TrivialFunctionAnalysisVisitor
return true;

// Recursively descend into the callee to confirm that it's trivial as well.
return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
return IsFunctionTrivial(Callee);
}

bool VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) {
Expand All @@ -413,7 +450,7 @@ class TrivialFunctionAnalysisVisitor
if (!Callee)
return false;
// Recursively descend into the callee to confirm that it's trivial as well.
return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
return IsFunctionTrivial(Callee);
}

bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
Expand All @@ -439,7 +476,7 @@ class TrivialFunctionAnalysisVisitor
}

// Recursively descend into the callee to confirm that it's trivial.
return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache);
return IsFunctionTrivial(CE->getConstructor());
}

bool VisitCXXNewExpr(const CXXNewExpr *NE) { return VisitChildren(NE); }
Expand Down Expand Up @@ -513,36 +550,13 @@ class TrivialFunctionAnalysisVisitor

private:
CacheTy &Cache;
CacheTy RecursiveFn;
};

bool TrivialFunctionAnalysis::isTrivialImpl(
const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
// If the function isn't in the cache, conservatively assume that
// it's not trivial until analysis completes. This makes every recursive
// function non-trivial. This also guarantees that each function
// will be scanned at most once.
auto [It, IsNew] = Cache.insert(std::make_pair(D, false));
if (!IsNew)
return It->second;

TrivialFunctionAnalysisVisitor V(Cache);

if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
for (auto *CtorInit : CtorDecl->inits()) {
if (!V.Visit(CtorInit->getInit()))
return false;
}
}

const Stmt *Body = D->getBody();
if (!Body)
return false;

bool Result = V.Visit(Body);
if (Result)
Cache[D] = true;

return Result;
return V.IsFunctionTrivial(D);
}

bool TrivialFunctionAnalysis::isTrivialImpl(
Expand Down
30 changes: 30 additions & 0 deletions clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,18 @@ class RefCounted {
void method();
void someFunction();
int otherFunction();
unsigned recursiveTrivialFunction(int n) { return !n ? 1 : recursiveTrivialFunction(n - 1); }
unsigned recursiveComplexFunction(int n) { return !n ? otherFunction() : recursiveComplexFunction(n - 1); }
unsigned mutuallyRecursiveFunction1(int n) { return n < 0 ? 1 : (n % 2 ? mutuallyRecursiveFunction2(n - 2) : mutuallyRecursiveFunction1(n - 1)); }
unsigned mutuallyRecursiveFunction2(int n) { return n < 0 ? 1 : (n % 3 ? mutuallyRecursiveFunction2(n - 3) : mutuallyRecursiveFunction1(n - 2)); }
unsigned mutuallyRecursiveFunction3(int n) { return n < 0 ? 1 : (n % 5 ? mutuallyRecursiveFunction3(n - 5) : mutuallyRecursiveFunction4(n - 3)); }
unsigned mutuallyRecursiveFunction4(int n) { return n < 0 ? 1 : (n % 7 ? otherFunction() : mutuallyRecursiveFunction3(n - 3)); }
unsigned recursiveFunction5(unsigned n) { return n > 100 ? 2 : (n % 2 ? recursiveFunction5(n + 1) : recursiveFunction6(n + 2)); }
unsigned recursiveFunction6(unsigned n) { return n > 100 ? 3 : (n % 2 ? recursiveFunction6(n % 7) : recursiveFunction7(n % 5)); }
unsigned recursiveFunction7(unsigned n) { return n > 100 ? 5 : recursiveFunction7(n * 5); }

void mutuallyRecursive8() { mutuallyRecursive9(); someFunction(); }
void mutuallyRecursive9() { mutuallyRecursive8(); }

int trivial1() { return 123; }
float trivial2() { return 0.3; }
Expand Down Expand Up @@ -498,6 +510,24 @@ class UnrelatedClass {
RefCounted::singleton().trivial18(); // no-warning
RefCounted::singleton().someFunction(); // no-warning

getFieldTrivial().recursiveTrivialFunction(7); // no-warning
getFieldTrivial().recursiveComplexFunction(9);
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
getFieldTrivial().mutuallyRecursiveFunction1(11); // no-warning
getFieldTrivial().mutuallyRecursiveFunction2(13); // no-warning
getFieldTrivial().mutuallyRecursiveFunction3(17);
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
getFieldTrivial().mutuallyRecursiveFunction4(19);
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
getFieldTrivial().recursiveFunction5(23); // no-warning
getFieldTrivial().recursiveFunction6(29); // no-warning
getFieldTrivial().recursiveFunction7(31); // no-warning

getFieldTrivial().mutuallyRecursive8();
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
getFieldTrivial().mutuallyRecursive9();
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}

getFieldTrivial().someFunction();
// expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
getFieldTrivial().nonTrivial1();
Expand Down
Loading