1
+ // ===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
2
+ //
3
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+
9
+ #include " clang/AST/Type.h"
10
+ #include " clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
11
+ #include " clang/StaticAnalyzer/Core/BugReporter/BugType.h"
12
+ #include " clang/StaticAnalyzer/Core/Checker.h"
13
+ #include " clang/StaticAnalyzer/Core/CheckerManager.h"
14
+ #include " clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
15
+ #include " clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
16
+ #include " clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
17
+ #include " clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
18
+ #include " llvm/ADT/FoldingSet.h"
19
+ #include " llvm/ADT/StringRef.h"
20
+ #include " llvm/Support/Casting.h"
21
+ #include < optional>
22
+ #include < string_view>
23
+
24
+ #include " TaggedUnionModeling.h"
25
+
26
+ using namespace clang ;
27
+ using namespace ento ;
28
+ using namespace tagged_union_modeling ;
29
+
30
+ REGISTER_MAP_WITH_PROGRAMSTATE (VariantHeldTypeMap, const MemRegion *, QualType)
31
+
32
+ namespace clang {
33
+ namespace ento {
34
+ namespace tagged_union_modeling {
35
+
36
+ // Returns the CallEvent representing the caller of the function
37
+ // It is needed because the CallEvent class does not contain enough information
38
+ // to tell who called it. Checker context is needed.
39
+ CallEventRef<> getCaller (const CallEvent &Call, const ProgramStateRef &State) {
40
+ const auto *CallLocationContext = Call.getLocationContext ();
41
+ if (!CallLocationContext || CallLocationContext->inTopFrame ())
42
+ return nullptr ;
43
+
44
+ const auto *CallStackFrameContext = CallLocationContext->getStackFrame ();
45
+ if (!CallStackFrameContext)
46
+ return nullptr ;
47
+
48
+ CallEventManager &CEMgr = State->getStateManager ().getCallEventManager ();
49
+ return CEMgr.getCaller (CallStackFrameContext, State);
50
+ }
51
+
52
+ const CXXConstructorDecl *
53
+ getConstructorDeclarationForCall (const CallEvent &Call) {
54
+ const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
55
+ if (!ConstructorCall)
56
+ return nullptr ;
57
+
58
+ return ConstructorCall->getDecl ();
59
+ }
60
+
61
+ bool isCopyConstructorCall (const CallEvent &Call) {
62
+ if (const CXXConstructorDecl *ConstructorDecl =
63
+ getConstructorDeclarationForCall (Call))
64
+ return ConstructorDecl->isCopyConstructor ();
65
+ return false ;
66
+ }
67
+
68
+ bool isCopyAssignmentCall (const CallEvent &Call) {
69
+ const Decl *CopyAssignmentDecl = Call.getDecl ();
70
+
71
+ if (const auto *AsMethodDecl =
72
+ dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
73
+ return AsMethodDecl->isCopyAssignmentOperator ();
74
+ return false ;
75
+ }
76
+
77
+ bool isMoveConstructorCall (const CallEvent &Call) {
78
+ const CXXConstructorDecl *ConstructorDecl =
79
+ getConstructorDeclarationForCall (Call);
80
+ if (!ConstructorDecl)
81
+ return false ;
82
+
83
+ return ConstructorDecl->isMoveConstructor ();
84
+ }
85
+
86
+ bool isMoveAssignmentCall (const CallEvent &Call) {
87
+ const Decl *CopyAssignmentDecl = Call.getDecl ();
88
+
89
+ const auto *AsMethodDecl =
90
+ dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
91
+ if (!AsMethodDecl)
92
+ return false ;
93
+
94
+ return AsMethodDecl->isMoveAssignmentOperator ();
95
+ }
96
+
97
+ bool isStdType (const Type *Type, llvm::StringRef TypeName) {
98
+ auto *Decl = Type->getAsRecordDecl ();
99
+ if (!Decl)
100
+ return false ;
101
+ return (Decl->getName () == TypeName) && Decl->isInStdNamespace ();
102
+ }
103
+
104
+ bool isStdVariant (const Type *Type) {
105
+ return isStdType (Type, llvm::StringLiteral (" variant" ));
106
+ }
107
+
108
+ bool calledFromSystemHeader (const CallEvent &Call,
109
+ const ProgramStateRef &State) {
110
+ if (CallEventRef<> Caller = getCaller (Call, State))
111
+ return Caller->isInSystemHeader ();
112
+
113
+ return false ;
114
+ }
115
+
116
+ bool calledFromSystemHeader (const CallEvent &Call, CheckerContext &C) {
117
+ return calledFromSystemHeader (Call, C.getState ());
118
+ }
119
+
120
+ } // end of namespace tagged_union_modeling
121
+ } // end of namespace ento
122
+ } // end of namespace clang
123
+
124
+ static std::optional<ArrayRef<TemplateArgument>>
125
+ getTemplateArgsFromVariant (const Type *VariantType) {
126
+ const auto *TempSpecType = VariantType->getAs <TemplateSpecializationType>();
127
+ if (!TempSpecType)
128
+ return {};
129
+
130
+ return TempSpecType->template_arguments ();
131
+ }
132
+
133
+ static std::optional<QualType>
134
+ getNthTemplateTypeArgFromVariant (const Type *varType, unsigned i) {
135
+ std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
136
+ getTemplateArgsFromVariant (varType);
137
+ if (!VariantTemplates)
138
+ return {};
139
+
140
+ return (*VariantTemplates)[i].getAsType ();
141
+ }
142
+
143
+ static bool isVowel (char a) {
144
+ switch (a) {
145
+ case ' a' :
146
+ case ' e' :
147
+ case ' i' :
148
+ case ' o' :
149
+ case ' u' :
150
+ return true ;
151
+ default :
152
+ return false ;
153
+ }
154
+ }
155
+
156
+ static llvm::StringRef indefiniteArticleBasedOnVowel (char a) {
157
+ if (isVowel (a))
158
+ return " an" ;
159
+ return " a" ;
160
+ }
161
+
162
+ class StdVariantChecker : public Checker <eval::Call, check::RegionChanges> {
163
+ // Call descriptors to find relevant calls
164
+ CallDescription VariantConstructor{{" std" , " variant" , " variant" }};
165
+ CallDescription VariantAssignmentOperator{{" std" , " variant" , " operator=" }};
166
+ CallDescription StdGet{{" std" , " get" }, 1 , 1 };
167
+
168
+ BugType BadVariantType{this , " BadVariantType" , " BadVariantType" };
169
+
170
+ public:
171
+ ProgramStateRef checkRegionChanges (ProgramStateRef State,
172
+ const InvalidatedSymbols *,
173
+ ArrayRef<const MemRegion *>,
174
+ ArrayRef<const MemRegion *> Regions,
175
+ const LocationContext *,
176
+ const CallEvent *Call) const {
177
+ return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
178
+ Call, State, Regions);
179
+ }
180
+
181
+ bool evalCall (const CallEvent &Call, CheckerContext &C) const {
182
+ // Check if the call was not made from a system header. If it was then
183
+ // we do an early return because it is part of the implementation.
184
+ if (calledFromSystemHeader (Call, C))
185
+ return false ;
186
+
187
+ if (StdGet.matches (Call))
188
+ return handleStdGetCall (Call, C);
189
+
190
+ // First check if a constructor call is happening. If it is a
191
+ // constructor call, check if it is an std::variant constructor call.
192
+ bool IsVariantConstructor =
193
+ isa<CXXConstructorCall>(Call) && VariantConstructor.matches (Call);
194
+ bool IsVariantAssignmentOperatorCall =
195
+ isa<CXXMemberOperatorCall>(Call) &&
196
+ VariantAssignmentOperator.matches (Call);
197
+
198
+ if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
199
+ if (Call.getNumArgs () == 0 && IsVariantConstructor) {
200
+ handleDefaultConstructor (cast<CXXConstructorCall>(&Call), C);
201
+ return true ;
202
+ }
203
+
204
+ // FIXME Later this checker should be extended to handle constructors
205
+ // with multiple arguments.
206
+ if (Call.getNumArgs () != 1 )
207
+ return false ;
208
+
209
+ SVal ThisSVal;
210
+ if (IsVariantConstructor) {
211
+ const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
212
+ ThisSVal = AsConstructorCall.getCXXThisVal ();
213
+ } else if (IsVariantAssignmentOperatorCall) {
214
+ const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
215
+ ThisSVal = AsMemberOpCall.getCXXThisVal ();
216
+ } else {
217
+ return false ;
218
+ }
219
+
220
+ handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
221
+ return true ;
222
+ }
223
+ return false ;
224
+ }
225
+
226
+ private:
227
+ // The default constructed std::variant must be handled separately
228
+ // by default the std::variant is going to hold a default constructed instance
229
+ // of the first type of the possible types
230
+ void handleDefaultConstructor (const CXXConstructorCall *ConstructorCall,
231
+ CheckerContext &C) const {
232
+ SVal ThisSVal = ConstructorCall->getCXXThisVal ();
233
+
234
+ const auto *const ThisMemRegion = ThisSVal.getAsRegion ();
235
+ if (!ThisMemRegion)
236
+ return ;
237
+
238
+ std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant (
239
+ ThisSVal.getType (C.getASTContext ())->getPointeeType ().getTypePtr (), 0 );
240
+ if (!DefaultType)
241
+ return ;
242
+
243
+ ProgramStateRef State = ConstructorCall->getState ();
244
+ State = State->set <VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
245
+ C.addTransition (State);
246
+ }
247
+
248
+ bool handleStdGetCall (const CallEvent &Call, CheckerContext &C) const {
249
+ ProgramStateRef State = Call.getState ();
250
+
251
+ const auto &ArgType = Call.getArgSVal (0 )
252
+ .getType (C.getASTContext ())
253
+ ->getPointeeType ()
254
+ .getTypePtr ();
255
+ // We have to make sure that the argument is an std::variant.
256
+ // There is another std::get with std::pair argument
257
+ if (!isStdVariant (ArgType))
258
+ return false ;
259
+
260
+ // Get the mem region of the argument std::variant and look up the type
261
+ // information that we know about it.
262
+ const MemRegion *ArgMemRegion = Call.getArgSVal (0 ).getAsRegion ();
263
+ const QualType *StoredType = State->get <VariantHeldTypeMap>(ArgMemRegion);
264
+ if (!StoredType)
265
+ return false ;
266
+
267
+ const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr ());
268
+ const FunctionDecl *FD = CE->getDirectCallee ();
269
+ if (FD->getTemplateSpecializationArgs ()->size () < 1 )
270
+ return false ;
271
+
272
+ const auto &TypeOut = FD->getTemplateSpecializationArgs ()->asArray ()[0 ];
273
+ // std::get's first template parameter can be the type we want to get
274
+ // out of the std::variant or a natural number which is the position of
275
+ // the requested type in the argument type list of the std::variant's
276
+ // argument.
277
+ QualType RetrievedType;
278
+ switch (TypeOut.getKind ()) {
279
+ case TemplateArgument::ArgKind::Type:
280
+ RetrievedType = TypeOut.getAsType ();
281
+ break ;
282
+ case TemplateArgument::ArgKind::Integral:
283
+ // In the natural number case we look up which type corresponds to the
284
+ // number.
285
+ if (std::optional<QualType> NthTemplate =
286
+ getNthTemplateTypeArgFromVariant (
287
+ ArgType, TypeOut.getAsIntegral ().getSExtValue ())) {
288
+ RetrievedType = *NthTemplate;
289
+ break ;
290
+ }
291
+ [[fallthrough]];
292
+ default :
293
+ return false ;
294
+ }
295
+
296
+ QualType RetrievedCanonicalType = RetrievedType.getCanonicalType ();
297
+ QualType StoredCanonicalType = StoredType->getCanonicalType ();
298
+ if (RetrievedCanonicalType == StoredCanonicalType)
299
+ return true ;
300
+
301
+ ExplodedNode *ErrNode = C.generateNonFatalErrorNode ();
302
+ if (!ErrNode)
303
+ return false ;
304
+ llvm::SmallString<128 > Str;
305
+ llvm::raw_svector_ostream OS (Str);
306
+ std::string StoredTypeName = StoredType->getAsString ();
307
+ std::string RetrievedTypeName = RetrievedType.getAsString ();
308
+ OS << " std::variant " << ArgMemRegion->getDescriptiveName () << " held "
309
+ << indefiniteArticleBasedOnVowel (StoredTypeName[0 ]) << " \' "
310
+ << StoredTypeName << " \' , not "
311
+ << indefiniteArticleBasedOnVowel (RetrievedTypeName[0 ]) << " \' "
312
+ << RetrievedTypeName << " \' " ;
313
+ auto R = std::make_unique<PathSensitiveBugReport>(BadVariantType, OS.str (),
314
+ ErrNode);
315
+ C.emitReport (std::move (R));
316
+ return true ;
317
+ }
318
+ };
319
+
320
+ bool clang::ento::shouldRegisterStdVariantChecker (
321
+ clang::ento::CheckerManager const &mgr) {
322
+ return true ;
323
+ }
324
+
325
+ void clang::ento::registerStdVariantChecker (clang::ento::CheckerManager &mgr) {
326
+ mgr.registerChecker <StdVariantChecker>();
327
+ }
0 commit comments