@@ -64,39 +64,117 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
64
64
return false ;
65
65
}
66
66
67
+ static const CXXRecordDecl *getOptionalBaseClass (const CXXRecordDecl *RD) {
68
+ if (RD == nullptr )
69
+ return nullptr ;
70
+ if (hasOptionalClassName (*RD))
71
+ return RD;
72
+
73
+ if (!RD->hasDefinition ())
74
+ return nullptr ;
75
+
76
+ for (const CXXBaseSpecifier &Base : RD->bases ())
77
+ if (const CXXRecordDecl *BaseClass =
78
+ getOptionalBaseClass (Base.getType ()->getAsCXXRecordDecl ()))
79
+ return BaseClass;
80
+
81
+ return nullptr ;
82
+ }
83
+
67
84
namespace {
68
85
69
86
using namespace ::clang::ast_matchers;
70
87
using LatticeTransferState = TransferState<NoopLattice>;
71
88
72
- AST_MATCHER (CXXRecordDecl, hasOptionalClassNameMatcher) {
73
- return hasOptionalClassName (Node);
89
+ AST_MATCHER (CXXRecordDecl, optionalClass) { return hasOptionalClassName (Node); }
90
+
91
+ AST_MATCHER (CXXRecordDecl, optionalOrDerivedClass) {
92
+ return getOptionalBaseClass (&Node) != nullptr ;
74
93
}
75
94
76
- DeclarationMatcher optionalClass () {
77
- return classTemplateSpecializationDecl (
78
- hasOptionalClassNameMatcher (),
79
- hasTemplateArgument (0 , refersToType (type ().bind (" T" ))));
95
+ auto desugarsToOptionalType () {
96
+ return hasUnqualifiedDesugaredType (
97
+ recordType (hasDeclaration (cxxRecordDecl (optionalClass ()))));
80
98
}
81
99
82
- auto optionalOrAliasType () {
100
+ auto desugarsToOptionalOrDerivedType () {
83
101
return hasUnqualifiedDesugaredType (
84
- recordType (hasDeclaration (optionalClass ( ))));
102
+ recordType (hasDeclaration (cxxRecordDecl ( optionalOrDerivedClass () ))));
85
103
}
86
104
87
- // / Matches any of the spellings of the optional types and sugar, aliases, etc.
88
- auto hasOptionalType () { return hasType (optionalOrAliasType ()); }
105
+ auto hasOptionalType () { return hasType (desugarsToOptionalType ()); }
106
+
107
+ // / Matches any of the spellings of the optional types and sugar, aliases,
108
+ // / derived classes, etc.
109
+ auto hasOptionalOrDerivedType () {
110
+ return hasType (desugarsToOptionalOrDerivedType ());
111
+ }
112
+
113
+ QualType getPublicType (const Expr *E) {
114
+ auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens ());
115
+ if (Cast == nullptr || Cast->getCastKind () != CK_UncheckedDerivedToBase) {
116
+ QualType Ty = E->getType ();
117
+ if (Ty->isPointerType ())
118
+ return Ty->getPointeeType ();
119
+ return Ty;
120
+ }
121
+
122
+ QualType Ty = getPublicType (Cast->getSubExpr ());
123
+
124
+ // Is `Ty` the type of `*this`? In this special case, we can upcast to the
125
+ // base class even if the base is non-public.
126
+ bool TyIsThisType = isa<CXXThisExpr>(Cast->getSubExpr ());
127
+
128
+ for (const CXXBaseSpecifier *Base : Cast->path ()) {
129
+ if (Base->getAccessSpecifier () != AS_public && !TyIsThisType)
130
+ break ;
131
+ Ty = Base->getType ();
132
+ TyIsThisType = false ;
133
+ }
134
+
135
+ return Ty;
136
+ }
137
+
138
+ // Returns the least-derived type for the receiver of `MCE` that
139
+ // `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to.
140
+ // Effectively, we upcast until we reach a non-public base class, unless that
141
+ // base is a base of `*this`.
142
+ //
143
+ // This is needed to correctly match methods called on types derived from
144
+ // `std::optional`.
145
+ //
146
+ // Say we have a `struct Derived : public std::optional<int> {} d;` For a call
147
+ // `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
148
+ //
149
+ // ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
150
+ // | <UncheckedDerivedToBase (optional -> __optional_storage_base)>
151
+ // `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
152
+ //
153
+ // The type of the implicit object argument is `__optional_storage_base`
154
+ // (since this is the internal type that `has_value()` is declared on). If we
155
+ // call `IgnoreParenImpCasts()` on the implicit object argument, we get the
156
+ // `DeclRefExpr`, which has type `Derived`. Neither of these types is
157
+ // `optional`, and hence neither is sufficient for querying whether we are
158
+ // calling a method on `optional`.
159
+ //
160
+ // Instead, starting with the most derived type, we need to follow the chain of
161
+ // casts
162
+ QualType getPublicReceiverType (const CXXMemberCallExpr &MCE) {
163
+ return getPublicType (MCE.getImplicitObjectArgument ());
164
+ }
165
+
166
+ AST_MATCHER_P (CXXMemberCallExpr, publicReceiverType,
167
+ ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
168
+ return InnerMatcher.matches (getPublicReceiverType (Node), Finder, Builder);
169
+ }
89
170
90
171
auto isOptionalMemberCallWithNameMatcher (
91
172
ast_matchers::internal::Matcher<NamedDecl> matcher,
92
173
const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
93
- auto Exception = unless (Ignorable ? expr (anyOf (*Ignorable, cxxThisExpr ()))
94
- : cxxThisExpr ());
95
- return cxxMemberCallExpr (
96
- on (expr (Exception,
97
- anyOf (hasOptionalType (),
98
- hasType (pointerType (pointee (optionalOrAliasType ())))))),
99
- callee (cxxMethodDecl (matcher)));
174
+ return cxxMemberCallExpr (Ignorable ? on (expr (unless (*Ignorable)))
175
+ : anything (),
176
+ publicReceiverType (desugarsToOptionalType ()),
177
+ callee (cxxMethodDecl (matcher)));
100
178
}
101
179
102
180
auto isOptionalOperatorCallWithName (
@@ -129,19 +207,19 @@ auto inPlaceClass() {
129
207
130
208
auto isOptionalNulloptConstructor () {
131
209
return cxxConstructExpr (
132
- hasOptionalType (),
210
+ hasOptionalOrDerivedType (),
133
211
hasDeclaration (cxxConstructorDecl (parameterCountIs (1 ),
134
212
hasParameter (0 , hasNulloptType ()))));
135
213
}
136
214
137
215
auto isOptionalInPlaceConstructor () {
138
- return cxxConstructExpr (hasOptionalType (),
216
+ return cxxConstructExpr (hasOptionalOrDerivedType (),
139
217
hasArgument (0 , hasType (inPlaceClass ())));
140
218
}
141
219
142
220
auto isOptionalValueOrConversionConstructor () {
143
221
return cxxConstructExpr (
144
- hasOptionalType (),
222
+ hasOptionalOrDerivedType (),
145
223
unless (hasDeclaration (
146
224
cxxConstructorDecl (anyOf (isCopyConstructor (), isMoveConstructor ())))),
147
225
argumentCountIs (1 ), hasArgument (0 , unless (hasNulloptType ())));
@@ -150,28 +228,30 @@ auto isOptionalValueOrConversionConstructor() {
150
228
auto isOptionalValueOrConversionAssignment () {
151
229
return cxxOperatorCallExpr (
152
230
hasOverloadedOperatorName (" =" ),
153
- callee (cxxMethodDecl (ofClass (optionalClass ()))),
231
+ callee (cxxMethodDecl (ofClass (optionalOrDerivedClass ()))),
154
232
unless (hasDeclaration (cxxMethodDecl (
155
233
anyOf (isCopyAssignmentOperator (), isMoveAssignmentOperator ())))),
156
234
argumentCountIs (2 ), hasArgument (1 , unless (hasNulloptType ())));
157
235
}
158
236
159
237
auto isOptionalNulloptAssignment () {
160
- return cxxOperatorCallExpr (hasOverloadedOperatorName ( " = " ),
161
- callee ( cxxMethodDecl ( ofClass ( optionalClass ())) ),
162
- argumentCountIs ( 2 ),
163
- hasArgument (1 , hasNulloptType ()));
238
+ return cxxOperatorCallExpr (
239
+ hasOverloadedOperatorName ( " = " ),
240
+ callee ( cxxMethodDecl ( ofClass ( optionalOrDerivedClass ())) ),
241
+ argumentCountIs ( 2 ), hasArgument (1 , hasNulloptType ()));
164
242
}
165
243
166
244
auto isStdSwapCall () {
167
245
return callExpr (callee (functionDecl (hasName (" std::swap" ))),
168
- argumentCountIs (2 ), hasArgument (0 , hasOptionalType ()),
169
- hasArgument (1 , hasOptionalType ()));
246
+ argumentCountIs (2 ),
247
+ hasArgument (0 , hasOptionalOrDerivedType ()),
248
+ hasArgument (1 , hasOptionalOrDerivedType ()));
170
249
}
171
250
172
251
auto isStdForwardCall () {
173
252
return callExpr (callee (functionDecl (hasName (" std::forward" ))),
174
- argumentCountIs (1 ), hasArgument (0 , hasOptionalType ()));
253
+ argumentCountIs (1 ),
254
+ hasArgument (0 , hasOptionalOrDerivedType ()));
175
255
}
176
256
177
257
constexpr llvm::StringLiteral ValueOrCallID = " ValueOrCall" ;
@@ -212,8 +292,9 @@ auto isValueOrNotEqX() {
212
292
}
213
293
214
294
auto isCallReturningOptional () {
215
- return callExpr (hasType (qualType (anyOf (
216
- optionalOrAliasType (), referenceType (pointee (optionalOrAliasType ()))))));
295
+ return callExpr (hasType (qualType (
296
+ anyOf (desugarsToOptionalOrDerivedType (),
297
+ referenceType (pointee (desugarsToOptionalOrDerivedType ()))))));
217
298
}
218
299
219
300
template <typename L, typename R>
@@ -275,28 +356,23 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
275
356
return HasValueVal;
276
357
}
277
358
278
- // / Returns true if and only if `Type` is an optional type.
279
- bool isOptionalType (QualType Type) {
280
- if (!Type->isRecordType ())
281
- return false ;
282
- const CXXRecordDecl *D = Type->getAsCXXRecordDecl ();
283
- return D != nullptr && hasOptionalClassName (*D);
359
+ QualType valueTypeFromOptionalDecl (const CXXRecordDecl &RD) {
360
+ auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
361
+ return CTSD.getTemplateArgs ()[0 ].getAsType ();
284
362
}
285
363
286
364
// / Returns the number of optional wrappers in `Type`.
287
365
// /
288
366
// / For example, if `Type` is `optional<optional<int>>`, the result of this
289
367
// / function will be 2.
290
368
int countOptionalWrappers (const ASTContext &ASTCtx, QualType Type) {
291
- if (!isOptionalType (Type))
369
+ const CXXRecordDecl *Optional =
370
+ getOptionalBaseClass (Type->getAsCXXRecordDecl ());
371
+ if (Optional == nullptr )
292
372
return 0 ;
293
373
return 1 + countOptionalWrappers (
294
374
ASTCtx,
295
- cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl ())
296
- ->getTemplateArgs ()
297
- .get (0 )
298
- .getAsType ()
299
- .getDesugaredType (ASTCtx));
375
+ valueTypeFromOptionalDecl (*Optional).getDesugaredType (ASTCtx));
300
376
}
301
377
302
378
StorageLocation *getLocBehindPossiblePointer (const Expr &E,
@@ -843,13 +919,7 @@ auto buildDiagnoseMatchSwitch(
843
919
844
920
ast_matchers::DeclarationMatcher
845
921
UncheckedOptionalAccessModel::optionalClassDecl () {
846
- return optionalClass ();
847
- }
848
-
849
- static QualType valueTypeFromOptionalType (QualType OptionalTy) {
850
- auto *CTSD =
851
- cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl ());
852
- return CTSD->getTemplateArgs ()[0 ].getAsType ();
922
+ return cxxRecordDecl (optionalClass ());
853
923
}
854
924
855
925
UncheckedOptionalAccessModel::UncheckedOptionalAccessModel (ASTContext &Ctx,
@@ -858,9 +928,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
858
928
TransferMatchSwitch (buildTransferMatchSwitch()) {
859
929
Env.getDataflowAnalysisContext ().setSyntheticFieldCallback (
860
930
[&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
861
- if (!isOptionalType (Ty))
931
+ const CXXRecordDecl *Optional =
932
+ getOptionalBaseClass (Ty->getAsCXXRecordDecl ());
933
+ if (Optional == nullptr )
862
934
return {};
863
- return {{" value" , valueTypeFromOptionalType (Ty )},
935
+ return {{" value" , valueTypeFromOptionalDecl (*Optional )},
864
936
{" has_value" , Ctx.BoolTy }};
865
937
});
866
938
}
0 commit comments