Skip to content

Commit 209aac4

Browse files
committed
[clang][dataflow] Make optional checker work for types derived from optional.
`llvm::MaybeAlign` does this, for example. It's not an option to simply ignore these derived classes because they get cast back to the optional classes (for example, simply when calling the optional member functions), and our transfer functions will then run on those optional classes and therefore require them to be properly initialized.
1 parent 6cd68c2 commit 209aac4

File tree

2 files changed

+183
-51
lines changed

2 files changed

+183
-51
lines changed

clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp

Lines changed: 123 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -64,39 +64,117 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
6464
return false;
6565
}
6666

67+
static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
68+
if (RD == nullptr)
69+
return nullptr;
70+
if (hasOptionalClassName(*RD))
71+
return RD;
72+
73+
if (!RD->hasDefinition())
74+
return nullptr;
75+
76+
for (const CXXBaseSpecifier &Base : RD->bases())
77+
if (const CXXRecordDecl *BaseClass =
78+
getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
79+
return BaseClass;
80+
81+
return nullptr;
82+
}
83+
6784
namespace {
6885

6986
using namespace ::clang::ast_matchers;
7087
using LatticeTransferState = TransferState<NoopLattice>;
7188

72-
AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
73-
return hasOptionalClassName(Node);
89+
AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }
90+
91+
AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
92+
return getOptionalBaseClass(&Node) != nullptr;
7493
}
7594

76-
DeclarationMatcher optionalClass() {
77-
return classTemplateSpecializationDecl(
78-
hasOptionalClassNameMatcher(),
79-
hasTemplateArgument(0, refersToType(type().bind("T"))));
95+
auto desugarsToOptionalType() {
96+
return hasUnqualifiedDesugaredType(
97+
recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
8098
}
8199

82-
auto optionalOrAliasType() {
100+
auto desugarsToOptionalOrDerivedType() {
83101
return hasUnqualifiedDesugaredType(
84-
recordType(hasDeclaration(optionalClass())));
102+
recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
85103
}
86104

87-
/// Matches any of the spellings of the optional types and sugar, aliases, etc.
88-
auto hasOptionalType() { return hasType(optionalOrAliasType()); }
105+
auto hasOptionalType() { return hasType(desugarsToOptionalType()); }
106+
107+
/// Matches any of the spellings of the optional types and sugar, aliases,
108+
/// derived classes, etc.
109+
auto hasOptionalOrDerivedType() {
110+
return hasType(desugarsToOptionalOrDerivedType());
111+
}
112+
113+
QualType getPublicType(const Expr *E) {
114+
auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
115+
if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
116+
QualType Ty = E->getType();
117+
if (Ty->isPointerType())
118+
return Ty->getPointeeType();
119+
return Ty;
120+
}
121+
122+
QualType Ty = getPublicType(Cast->getSubExpr());
123+
124+
// Is `Ty` the type of `*this`? In this special case, we can upcast to the
125+
// base class even if the base is non-public.
126+
bool TyIsThisType = isa<CXXThisExpr>(Cast->getSubExpr());
127+
128+
for (const CXXBaseSpecifier *Base : Cast->path()) {
129+
if (Base->getAccessSpecifier() != AS_public && !TyIsThisType)
130+
break;
131+
Ty = Base->getType();
132+
TyIsThisType = false;
133+
}
134+
135+
return Ty;
136+
}
137+
138+
// Returns the least-derived type for the receiver of `MCE` that
139+
// `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to.
140+
// Effectively, we upcast until we reach a non-public base class, unless that
141+
// base is a base of `*this`.
142+
//
143+
// This is needed to correctly match methods called on types derived from
144+
// `std::optional`.
145+
//
146+
// Say we have a `struct Derived : public std::optional<int> {} d;` For a call
147+
// `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
148+
//
149+
// ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
150+
// | <UncheckedDerivedToBase (optional -> __optional_storage_base)>
151+
// `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
152+
//
153+
// The type of the implicit object argument is `__optional_storage_base`
154+
// (since this is the internal type that `has_value()` is declared on). If we
155+
// call `IgnoreParenImpCasts()` on the implicit object argument, we get the
156+
// `DeclRefExpr`, which has type `Derived`. Neither of these types is
157+
// `optional`, and hence neither is sufficient for querying whether we are
158+
// calling a method on `optional`.
159+
//
160+
// Instead, starting with the most derived type, we need to follow the chain of
161+
// casts
162+
QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) {
163+
return getPublicType(MCE.getImplicitObjectArgument());
164+
}
165+
166+
AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType,
167+
ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
168+
return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder);
169+
}
89170

90171
auto isOptionalMemberCallWithNameMatcher(
91172
ast_matchers::internal::Matcher<NamedDecl> matcher,
92173
const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
93-
auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
94-
: cxxThisExpr());
95-
return cxxMemberCallExpr(
96-
on(expr(Exception,
97-
anyOf(hasOptionalType(),
98-
hasType(pointerType(pointee(optionalOrAliasType())))))),
99-
callee(cxxMethodDecl(matcher)));
174+
return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable)))
175+
: anything(),
176+
publicReceiverType(desugarsToOptionalType()),
177+
callee(cxxMethodDecl(matcher)));
100178
}
101179

102180
auto isOptionalOperatorCallWithName(
@@ -129,19 +207,19 @@ auto inPlaceClass() {
129207

130208
auto isOptionalNulloptConstructor() {
131209
return cxxConstructExpr(
132-
hasOptionalType(),
210+
hasOptionalOrDerivedType(),
133211
hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
134212
hasParameter(0, hasNulloptType()))));
135213
}
136214

137215
auto isOptionalInPlaceConstructor() {
138-
return cxxConstructExpr(hasOptionalType(),
216+
return cxxConstructExpr(hasOptionalOrDerivedType(),
139217
hasArgument(0, hasType(inPlaceClass())));
140218
}
141219

142220
auto isOptionalValueOrConversionConstructor() {
143221
return cxxConstructExpr(
144-
hasOptionalType(),
222+
hasOptionalOrDerivedType(),
145223
unless(hasDeclaration(
146224
cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
147225
argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
@@ -150,28 +228,30 @@ auto isOptionalValueOrConversionConstructor() {
150228
auto isOptionalValueOrConversionAssignment() {
151229
return cxxOperatorCallExpr(
152230
hasOverloadedOperatorName("="),
153-
callee(cxxMethodDecl(ofClass(optionalClass()))),
231+
callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
154232
unless(hasDeclaration(cxxMethodDecl(
155233
anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
156234
argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
157235
}
158236

159237
auto isOptionalNulloptAssignment() {
160-
return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
161-
callee(cxxMethodDecl(ofClass(optionalClass()))),
162-
argumentCountIs(2),
163-
hasArgument(1, hasNulloptType()));
238+
return cxxOperatorCallExpr(
239+
hasOverloadedOperatorName("="),
240+
callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
241+
argumentCountIs(2), hasArgument(1, hasNulloptType()));
164242
}
165243

166244
auto isStdSwapCall() {
167245
return callExpr(callee(functionDecl(hasName("std::swap"))),
168-
argumentCountIs(2), hasArgument(0, hasOptionalType()),
169-
hasArgument(1, hasOptionalType()));
246+
argumentCountIs(2),
247+
hasArgument(0, hasOptionalOrDerivedType()),
248+
hasArgument(1, hasOptionalOrDerivedType()));
170249
}
171250

172251
auto isStdForwardCall() {
173252
return callExpr(callee(functionDecl(hasName("std::forward"))),
174-
argumentCountIs(1), hasArgument(0, hasOptionalType()));
253+
argumentCountIs(1),
254+
hasArgument(0, hasOptionalOrDerivedType()));
175255
}
176256

177257
constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
@@ -212,8 +292,9 @@ auto isValueOrNotEqX() {
212292
}
213293

214294
auto isCallReturningOptional() {
215-
return callExpr(hasType(qualType(anyOf(
216-
optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
295+
return callExpr(hasType(qualType(
296+
anyOf(desugarsToOptionalOrDerivedType(),
297+
referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
217298
}
218299

219300
template <typename L, typename R>
@@ -275,28 +356,23 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
275356
return HasValueVal;
276357
}
277358

278-
/// Returns true if and only if `Type` is an optional type.
279-
bool isOptionalType(QualType Type) {
280-
if (!Type->isRecordType())
281-
return false;
282-
const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
283-
return D != nullptr && hasOptionalClassName(*D);
359+
QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
360+
auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
361+
return CTSD.getTemplateArgs()[0].getAsType();
284362
}
285363

286364
/// Returns the number of optional wrappers in `Type`.
287365
///
288366
/// For example, if `Type` is `optional<optional<int>>`, the result of this
289367
/// function will be 2.
290368
int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
291-
if (!isOptionalType(Type))
369+
const CXXRecordDecl *Optional =
370+
getOptionalBaseClass(Type->getAsCXXRecordDecl());
371+
if (Optional == nullptr)
292372
return 0;
293373
return 1 + countOptionalWrappers(
294374
ASTCtx,
295-
cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
296-
->getTemplateArgs()
297-
.get(0)
298-
.getAsType()
299-
.getDesugaredType(ASTCtx));
375+
valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
300376
}
301377

302378
StorageLocation *getLocBehindPossiblePointer(const Expr &E,
@@ -843,13 +919,7 @@ auto buildDiagnoseMatchSwitch(
843919

844920
ast_matchers::DeclarationMatcher
845921
UncheckedOptionalAccessModel::optionalClassDecl() {
846-
return optionalClass();
847-
}
848-
849-
static QualType valueTypeFromOptionalType(QualType OptionalTy) {
850-
auto *CTSD =
851-
cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
852-
return CTSD->getTemplateArgs()[0].getAsType();
922+
return cxxRecordDecl(optionalClass());
853923
}
854924

855925
UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
@@ -858,9 +928,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
858928
TransferMatchSwitch(buildTransferMatchSwitch()) {
859929
Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
860930
[&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
861-
if (!isOptionalType(Ty))
931+
const CXXRecordDecl *Optional =
932+
getOptionalBaseClass(Ty->getAsCXXRecordDecl());
933+
if (Optional == nullptr)
862934
return {};
863-
return {{"value", valueTypeFromOptionalType(Ty)},
935+
return {{"value", valueTypeFromOptionalDecl(*Optional)},
864936
{"has_value", Ctx.BoolTy}};
865937
});
866938
}

clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3383,6 +3383,66 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) {
33833383
}
33843384
)");
33853385
}
3386+
3387+
TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) {
3388+
ExpectDiagnosticsFor(R"(
3389+
#include "unchecked_optional_access_test.h"
3390+
3391+
struct Derived : public $ns::$optional<int> {};
3392+
3393+
void target(Derived opt) {
3394+
*opt; // [[unsafe]]
3395+
if (opt.has_value())
3396+
*opt;
3397+
3398+
// The same thing, but with a pointer receiver.
3399+
Derived *popt = &opt;
3400+
**popt; // [[unsafe]]
3401+
if (popt->has_value())
3402+
**popt;
3403+
}
3404+
)");
3405+
}
3406+
3407+
TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) {
3408+
ExpectDiagnosticsFor(R"(
3409+
#include "unchecked_optional_access_test.h"
3410+
3411+
template <class T>
3412+
struct Derived : public $ns::$optional<T> {};
3413+
3414+
void target(Derived<int> opt) {
3415+
*opt; // [[unsafe]]
3416+
if (opt.has_value())
3417+
*opt;
3418+
3419+
// The same thing, but with a pointer receiver.
3420+
Derived<int> *popt = &opt;
3421+
**popt; // [[unsafe]]
3422+
if (popt->has_value())
3423+
**popt;
3424+
}
3425+
)");
3426+
}
3427+
3428+
TEST_P(UncheckedOptionalAccessTest, ClassDerivedPrivatelyFromOptional) {
3429+
// Classes that derive privately from optional can themselves still call
3430+
// member functions of optional. Check that we model the optional correctly
3431+
// in this situation.
3432+
ExpectDiagnosticsFor(R"(
3433+
#include "unchecked_optional_access_test.h"
3434+
3435+
struct Derived : private $ns::$optional<int> {
3436+
void Method() {
3437+
**this; // [[unsafe]]
3438+
if (this->has_value())
3439+
**this;
3440+
}
3441+
};
3442+
)",
3443+
ast_matchers::hasName("Method"));
3444+
}
3445+
33863446
// FIXME: Add support for:
33873447
// - constructors (copy, move)
33883448
// - assignment operators (default, copy, move)

0 commit comments

Comments
 (0)