Skip to content

Commit 1948d22

Browse files
spaitsGabor Spaits
authored andcommitted
[analyzer] Add std::variant checker
Adding a checker that checks for bad std::variant type access.
1 parent c300884 commit 1948d22

File tree

7 files changed

+917
-1
lines changed

7 files changed

+917
-1
lines changed

clang/include/clang/StaticAnalyzer/Checkers/Checkers.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ def C11LockChecker : Checker<"C11Lock">,
318318
Dependencies<[PthreadLockBase]>,
319319
Documentation<HasDocumentation>;
320320

321+
def StdVariantChecker : Checker<"StdVariant">,
322+
HelpText<"Check for bad type access for std::variant.">,
323+
Documentation<NotDocumented>;
324+
321325
} // end "alpha.core"
322326

323327
//===----------------------------------------------------------------------===//

clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ add_clang_library(clangStaticAnalyzerCheckers
108108
SmartPtrModeling.cpp
109109
StackAddrEscapeChecker.cpp
110110
StdLibraryFunctionsChecker.cpp
111+
StdVariantChecker.cpp
111112
STLAlgorithmModeling.cpp
112113
StreamChecker.cpp
113114
StringChecker.cpp
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
//===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "clang/AST/Type.h"
10+
#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
11+
#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
12+
#include "clang/StaticAnalyzer/Core/Checker.h"
13+
#include "clang/StaticAnalyzer/Core/CheckerManager.h"
14+
#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
15+
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
16+
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
17+
#include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
18+
#include "llvm/ADT/FoldingSet.h"
19+
#include "llvm/ADT/StringRef.h"
20+
#include "llvm/Support/Casting.h"
21+
#include <optional>
22+
#include <string_view>
23+
24+
#include "TaggedUnionModeling.h"
25+
26+
using namespace clang;
27+
using namespace ento;
28+
using namespace tagged_union_modeling;
29+
30+
REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
31+
32+
namespace clang {
33+
namespace ento {
34+
namespace tagged_union_modeling {
35+
36+
// Returns the CallEvent representing the caller of the function
37+
// It is needed because the CallEvent class does not contain enough information
38+
// to tell who called it. Checker context is needed.
39+
CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State) {
40+
const auto *CallLocationContext = Call.getLocationContext();
41+
if (!CallLocationContext || CallLocationContext->inTopFrame())
42+
return nullptr;
43+
44+
const auto *CallStackFrameContext = CallLocationContext->getStackFrame();
45+
if (!CallStackFrameContext)
46+
return nullptr;
47+
48+
CallEventManager &CEMgr = State->getStateManager().getCallEventManager();
49+
return CEMgr.getCaller(CallStackFrameContext, State);
50+
}
51+
52+
const CXXConstructorDecl *
53+
getConstructorDeclarationForCall(const CallEvent &Call) {
54+
const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
55+
if (!ConstructorCall)
56+
return nullptr;
57+
58+
return ConstructorCall->getDecl();
59+
}
60+
61+
bool isCopyConstructorCall(const CallEvent &Call) {
62+
if (const CXXConstructorDecl *ConstructorDecl =
63+
getConstructorDeclarationForCall(Call))
64+
return ConstructorDecl->isCopyConstructor();
65+
return false;
66+
}
67+
68+
bool isCopyAssignmentCall(const CallEvent &Call) {
69+
const Decl *CopyAssignmentDecl = Call.getDecl();
70+
71+
if (const auto *AsMethodDecl =
72+
dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
73+
return AsMethodDecl->isCopyAssignmentOperator();
74+
return false;
75+
}
76+
77+
bool isMoveConstructorCall(const CallEvent &Call) {
78+
const CXXConstructorDecl *ConstructorDecl =
79+
getConstructorDeclarationForCall(Call);
80+
if (!ConstructorDecl)
81+
return false;
82+
83+
return ConstructorDecl->isMoveConstructor();
84+
}
85+
86+
bool isMoveAssignmentCall(const CallEvent &Call) {
87+
const Decl *CopyAssignmentDecl = Call.getDecl();
88+
89+
const auto *AsMethodDecl =
90+
dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
91+
if (!AsMethodDecl)
92+
return false;
93+
94+
return AsMethodDecl->isMoveAssignmentOperator();
95+
}
96+
97+
bool isStdType(const Type *Type, llvm::StringRef TypeName) {
98+
auto *Decl = Type->getAsRecordDecl();
99+
if (!Decl)
100+
return false;
101+
return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
102+
}
103+
104+
bool isStdVariant(const Type *Type) {
105+
return isStdType(Type, llvm::StringLiteral("variant"));
106+
}
107+
108+
bool calledFromSystemHeader(const CallEvent &Call,
109+
const ProgramStateRef &State) {
110+
if (CallEventRef<> Caller = getCaller(Call, State))
111+
return Caller->isInSystemHeader();
112+
113+
return false;
114+
}
115+
116+
bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C) {
117+
return calledFromSystemHeader(Call, C.getState());
118+
}
119+
120+
} // end of namespace tagged_union_modeling
121+
} // end of namespace ento
122+
} // end of namespace clang
123+
124+
static std::optional<ArrayRef<TemplateArgument>>
125+
getTemplateArgsFromVariant(const Type *VariantType) {
126+
const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
127+
if (!TempSpecType)
128+
return {};
129+
130+
return TempSpecType->template_arguments();
131+
}
132+
133+
static std::optional<QualType>
134+
getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
135+
std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
136+
getTemplateArgsFromVariant(varType);
137+
if (!VariantTemplates)
138+
return {};
139+
140+
return (*VariantTemplates)[i].getAsType();
141+
}
142+
143+
static bool isVowel(char a) {
144+
switch (a) {
145+
case 'a':
146+
case 'e':
147+
case 'i':
148+
case 'o':
149+
case 'u':
150+
return true;
151+
default:
152+
return false;
153+
}
154+
}
155+
156+
static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
157+
if (isVowel(a))
158+
return "an";
159+
return "a";
160+
}
161+
162+
class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
163+
// Call descriptors to find relevant calls
164+
CallDescription VariantConstructor{{"std", "variant", "variant"}};
165+
CallDescription VariantAssignmentOperator{{"std", "variant", "operator="}};
166+
CallDescription StdGet{{"std", "get"}, 1, 1};
167+
168+
BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
169+
170+
public:
171+
ProgramStateRef checkRegionChanges(ProgramStateRef State,
172+
const InvalidatedSymbols *,
173+
ArrayRef<const MemRegion *>,
174+
ArrayRef<const MemRegion *> Regions,
175+
const LocationContext *,
176+
const CallEvent *Call) const {
177+
return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
178+
Call, State, Regions);
179+
}
180+
181+
bool evalCall(const CallEvent &Call, CheckerContext &C) const {
182+
// Check if the call was not made from a system header. If it was then
183+
// we do an early return because it is part of the implementation.
184+
if (calledFromSystemHeader(Call, C))
185+
return false;
186+
187+
if (StdGet.matches(Call))
188+
return handleStdGetCall(Call, C);
189+
190+
// First check if a constructor call is happening. If it is a
191+
// constructor call, check if it is an std::variant constructor call.
192+
bool IsVariantConstructor =
193+
isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
194+
bool IsVariantAssignmentOperatorCall =
195+
isa<CXXMemberOperatorCall>(Call) &&
196+
VariantAssignmentOperator.matches(Call);
197+
198+
if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
199+
if (Call.getNumArgs() == 0 && IsVariantConstructor) {
200+
handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
201+
return true;
202+
}
203+
204+
// FIXME Later this checker should be extended to handle constructors
205+
// with multiple arguments.
206+
if (Call.getNumArgs() != 1)
207+
return false;
208+
209+
SVal ThisSVal;
210+
if (IsVariantConstructor) {
211+
const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
212+
ThisSVal = AsConstructorCall.getCXXThisVal();
213+
} else if (IsVariantAssignmentOperatorCall) {
214+
const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
215+
ThisSVal = AsMemberOpCall.getCXXThisVal();
216+
} else {
217+
return false;
218+
}
219+
220+
handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
221+
return true;
222+
}
223+
return false;
224+
}
225+
226+
private:
227+
// The default constructed std::variant must be handled separately
228+
// by default the std::variant is going to hold a default constructed instance
229+
// of the first type of the possible types
230+
void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
231+
CheckerContext &C) const {
232+
SVal ThisSVal = ConstructorCall->getCXXThisVal();
233+
234+
const auto *const ThisMemRegion = ThisSVal.getAsRegion();
235+
if (!ThisMemRegion)
236+
return;
237+
238+
std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
239+
ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0);
240+
if (!DefaultType)
241+
return;
242+
243+
ProgramStateRef State = ConstructorCall->getState();
244+
State = State->set<VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
245+
C.addTransition(State);
246+
}
247+
248+
bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
249+
ProgramStateRef State = Call.getState();
250+
251+
const auto &ArgType = Call.getArgSVal(0)
252+
.getType(C.getASTContext())
253+
->getPointeeType()
254+
.getTypePtr();
255+
// We have to make sure that the argument is an std::variant.
256+
// There is another std::get with std::pair argument
257+
if (!isStdVariant(ArgType))
258+
return false;
259+
260+
// Get the mem region of the argument std::variant and look up the type
261+
// information that we know about it.
262+
const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
263+
const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
264+
if (!StoredType)
265+
return false;
266+
267+
const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
268+
const FunctionDecl *FD = CE->getDirectCallee();
269+
if (FD->getTemplateSpecializationArgs()->size() < 1)
270+
return false;
271+
272+
const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
273+
// std::get's first template parameter can be the type we want to get
274+
// out of the std::variant or a natural number which is the position of
275+
// the requested type in the argument type list of the std::variant's
276+
// argument.
277+
QualType RetrievedType;
278+
switch (TypeOut.getKind()) {
279+
case TemplateArgument::ArgKind::Type:
280+
RetrievedType = TypeOut.getAsType();
281+
break;
282+
case TemplateArgument::ArgKind::Integral:
283+
// In the natural number case we look up which type corresponds to the
284+
// number.
285+
if (std::optional<QualType> NthTemplate =
286+
getNthTemplateTypeArgFromVariant(
287+
ArgType, TypeOut.getAsIntegral().getSExtValue())) {
288+
RetrievedType = *NthTemplate;
289+
break;
290+
}
291+
[[fallthrough]];
292+
default:
293+
return false;
294+
}
295+
296+
QualType RetrievedCanonicalType = RetrievedType.getCanonicalType();
297+
QualType StoredCanonicalType = StoredType->getCanonicalType();
298+
if (RetrievedCanonicalType == StoredCanonicalType)
299+
return true;
300+
301+
ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
302+
if (!ErrNode)
303+
return false;
304+
llvm::SmallString<128> Str;
305+
llvm::raw_svector_ostream OS(Str);
306+
std::string StoredTypeName = StoredType->getAsString();
307+
std::string RetrievedTypeName = RetrievedType.getAsString();
308+
OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
309+
<< indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'"
310+
<< StoredTypeName << "\', not "
311+
<< indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'"
312+
<< RetrievedTypeName << "\'";
313+
auto R = std::make_unique<PathSensitiveBugReport>(BadVariantType, OS.str(),
314+
ErrNode);
315+
C.emitReport(std::move(R));
316+
return true;
317+
}
318+
};
319+
320+
bool clang::ento::shouldRegisterStdVariantChecker(
321+
clang::ento::CheckerManager const &mgr) {
322+
return true;
323+
}
324+
325+
void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
326+
mgr.registerChecker<StdVariantChecker>();
327+
}

0 commit comments

Comments
 (0)