Skip to content

Commit 28d2bef

Browse files
authored
Merge pull request #80946 from DougGregor/safe-nested-in-unsafe-fixes-6.2
[6.2] [Strict memory safety] Improve handling of safe types nested within unsafe ones
2 parents 61c4a03 + 19975bc commit 28d2bef

33 files changed

+579
-388
lines changed

include/swift/AST/Types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ class RecursiveTypeProperties {
318318
Bits &= ~HasDependentMember;
319319
}
320320

321+
/// Remove the IsUnsafe property from this set.
322+
void removeIsUnsafe() {
323+
Bits &= ~IsUnsafe;
324+
}
325+
321326
/// Test for a particular property in this set.
322327
bool operator&(Property prop) const {
323328
return Bits & prop;

lib/AST/ASTContext.cpp

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3543,6 +3543,12 @@ TypeAliasType *TypeAliasType::get(TypeAliasDecl *typealias, Type parent,
35433543
auto &ctx = underlying->getASTContext();
35443544
auto arena = getArena(properties);
35453545

3546+
// Typealiases can't meaningfully be unsafe; it's the underlying type that
3547+
// matters.
3548+
properties.removeIsUnsafe();
3549+
if (underlying->isUnsafe())
3550+
properties |= RecursiveTypeProperties::IsUnsafe;
3551+
35463552
// Profile the type.
35473553
llvm::FoldingSetNodeID id;
35483554
TypeAliasType::Profile(id, typealias, parent, genericArgs, underlying);
@@ -4190,6 +4196,54 @@ void UnboundGenericType::Profile(llvm::FoldingSetNodeID &ID,
41904196
ID.AddPointer(Parent.getPointer());
41914197
}
41924198

4199+
/// The safety of a parent type does not have an impact on a nested type within
4200+
/// it. This produces the recursive properties of a given type that should
4201+
/// be propagated to a nested type, which won't include any "IsUnsafe" bit
4202+
/// determined based on the declaration itself.
4203+
static RecursiveTypeProperties getRecursivePropertiesAsParent(Type type) {
4204+
if (!type)
4205+
return RecursiveTypeProperties();
4206+
4207+
// We only need to do anything interesting at all for unsafe types.
4208+
auto properties = type->getRecursiveProperties();
4209+
if (!properties.isUnsafe())
4210+
return properties;
4211+
4212+
if (auto nominal = type->getAnyNominal()) {
4213+
// If the nominal wasn't itself unsafe, then we got the unsafety from
4214+
// something else (e.g., a generic argument), so it won't change.
4215+
if (nominal->getExplicitSafety() != ExplicitSafety::Unsafe)
4216+
return properties;
4217+
}
4218+
4219+
// Drop the "unsafe" bit. We have to recompute it without considering the
4220+
// enclosing nominal type.
4221+
properties.removeIsUnsafe();
4222+
4223+
// Check generic arguments of parent types.
4224+
while (type) {
4225+
// Merge from the generic arguments.
4226+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
4227+
for (auto genericArg : boundGeneric->getGenericArgs())
4228+
properties |= genericArg->getRecursiveProperties();
4229+
}
4230+
4231+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>()) {
4232+
type = nominalOrBound->getParent();
4233+
continue;
4234+
}
4235+
4236+
if (auto unbound = type->getAs<UnboundGenericType>()) {
4237+
type = unbound->getParent();
4238+
continue;
4239+
}
4240+
4241+
break;
4242+
};
4243+
4244+
return properties;
4245+
}
4246+
41934247
UnboundGenericType *UnboundGenericType::
41944248
get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41954249
llvm::FoldingSetNodeID ID;
@@ -4198,7 +4252,7 @@ get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41984252
RecursiveTypeProperties properties;
41994253
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42004254
properties |= RecursiveTypeProperties::IsUnsafe;
4201-
if (Parent) properties |= Parent->getRecursiveProperties();
4255+
properties |= getRecursivePropertiesAsParent(Parent);
42024256

42034257
auto arena = getArena(properties);
42044258

@@ -4252,7 +4306,7 @@ BoundGenericType *BoundGenericType::get(NominalTypeDecl *TheDecl,
42524306
RecursiveTypeProperties properties;
42534307
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42544308
properties |= RecursiveTypeProperties::IsUnsafe;
4255-
if (Parent) properties |= Parent->getRecursiveProperties();
4309+
properties |= getRecursivePropertiesAsParent(Parent);
42564310
for (Type Arg : GenericArgs) {
42574311
properties |= Arg->getRecursiveProperties();
42584312
}
@@ -4335,7 +4389,7 @@ EnumType *EnumType::get(EnumDecl *D, Type Parent, const ASTContext &C) {
43354389
RecursiveTypeProperties properties;
43364390
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43374391
properties |= RecursiveTypeProperties::IsUnsafe;
4338-
if (Parent) properties |= Parent->getRecursiveProperties();
4392+
properties |= getRecursivePropertiesAsParent(Parent);
43394393
auto arena = getArena(properties);
43404394

43414395
auto *&known = C.getImpl().getArena(arena).EnumTypes[{D, Parent}];
@@ -4353,7 +4407,7 @@ StructType *StructType::get(StructDecl *D, Type Parent, const ASTContext &C) {
43534407
RecursiveTypeProperties properties;
43544408
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43554409
properties |= RecursiveTypeProperties::IsUnsafe;
4356-
if (Parent) properties |= Parent->getRecursiveProperties();
4410+
properties |= getRecursivePropertiesAsParent(Parent);
43574411
auto arena = getArena(properties);
43584412

43594413
auto *&known = C.getImpl().getArena(arena).StructTypes[{D, Parent}];
@@ -4371,7 +4425,7 @@ ClassType *ClassType::get(ClassDecl *D, Type Parent, const ASTContext &C) {
43714425
RecursiveTypeProperties properties;
43724426
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43734427
properties |= RecursiveTypeProperties::IsUnsafe;
4374-
if (Parent) properties |= Parent->getRecursiveProperties();
4428+
properties |= getRecursivePropertiesAsParent(Parent);
43754429
auto arena = getArena(properties);
43764430

43774431
auto *&known = C.getImpl().getArena(arena).ClassTypes[{D, Parent}];
@@ -5538,7 +5592,7 @@ ProtocolType *ProtocolType::get(ProtocolDecl *D, Type Parent,
55385592
RecursiveTypeProperties properties;
55395593
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
55405594
properties |= RecursiveTypeProperties::IsUnsafe;
5541-
if (Parent) properties |= Parent->getRecursiveProperties();
5595+
properties |= getRecursivePropertiesAsParent(Parent);
55425596
auto arena = getArena(properties);
55435597

55445598
auto *&known = C.getImpl().getArena(arena).ProtocolTypes[{D, Parent}];

lib/AST/Decl.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,12 +1212,14 @@ ExplicitSafety Decl::getExplicitSafety() const {
12121212
ExplicitSafety::Unspecified);
12131213
}
12141214

1215-
// Inference: Check the enclosing context.
1216-
if (auto enclosingDC = getDeclContext()) {
1217-
// Is this an extension with @safe or @unsafe on it?
1218-
if (auto ext = dyn_cast<ExtensionDecl>(enclosingDC)) {
1219-
if (auto extSafety = getExplicitSafetyFromAttrs(ext))
1220-
return *extSafety;
1215+
// Inference: Check the enclosing context, unless this is a type.
1216+
if (!isa<TypeDecl>(this)) {
1217+
if (auto enclosingDC = getDeclContext()) {
1218+
// Is this an extension with @safe or @unsafe on it?
1219+
if (auto ext = dyn_cast<ExtensionDecl>(enclosingDC)) {
1220+
if (auto extSafety = getExplicitSafetyFromAttrs(ext))
1221+
return *extSafety;
1222+
}
12211223
}
12221224
}
12231225

lib/Sema/TypeCheckEffects.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ class EffectsHandlingWalker : public ASTWalker {
670670
recurse = asImpl().checkForEach(forEach);
671671
} else if (auto labeled = dyn_cast<LabeledConditionalStmt>(S)) {
672672
asImpl().noteLabeledConditionalStmt(labeled);
673+
} else if (auto defer = dyn_cast<DeferStmt>(S)) {
674+
recurse = asImpl().checkDefer(defer);
673675
}
674676

675677
if (!recurse)
@@ -2110,6 +2112,10 @@ class ApplyClassifier {
21102112
return ShouldRecurse;
21112113
}
21122114

2115+
ShouldRecurse_t checkDefer(DeferStmt *S) {
2116+
return ShouldNotRecurse;
2117+
}
2118+
21132119
ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
21142120
return ShouldRecurse;
21152121
}
@@ -2255,6 +2261,10 @@ class ApplyClassifier {
22552261
return ShouldRecurse;
22562262
}
22572263

2264+
ShouldRecurse_t checkDefer(DeferStmt *S) {
2265+
return ShouldNotRecurse;
2266+
}
2267+
22582268
ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
22592269
return ShouldRecurse;
22602270
}
@@ -2354,6 +2364,10 @@ class ApplyClassifier {
23542364
return ShouldNotRecurse;
23552365
}
23562366

2367+
ShouldRecurse_t checkDefer(DeferStmt *S) {
2368+
return ShouldNotRecurse;
2369+
}
2370+
23572371
ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
23582372
return ShouldRecurse;
23592373
}
@@ -4398,6 +4412,17 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
43984412
return ShouldRecurse;
43994413
}
44004414

4415+
ShouldRecurse_t checkDefer(DeferStmt *S) {
4416+
// Pretend we're in an 'unsafe'.
4417+
ContextScope scope(*this, std::nullopt);
4418+
scope.enterUnsafe(S->getDeferLoc());
4419+
4420+
// Walk the call expression. We don't care about the rest.
4421+
S->getCallExpr()->walk(*this);
4422+
4423+
return ShouldNotRecurse;
4424+
}
4425+
44014426
void diagnoseRedundantTry(AnyTryExpr *E) const {
44024427
if (auto *SVE = SingleValueStmtExpr::tryDigOutSingleValueStmtExpr(E)) {
44034428
// For an if/switch expression, produce a tailored warning.

lib/Sema/TypeCheckUnsafe.cpp

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,12 @@ bool swift::enumerateUnsafeUses(ArrayRef<ProtocolConformanceRef> conformances,
330330
bool swift::enumerateUnsafeUses(SubstitutionMap subs,
331331
SourceLoc loc,
332332
llvm::function_ref<bool(UnsafeUse)> fn) {
333-
// FIXME: Check replacement types?
333+
// Replacement types.
334+
for (auto replacementType : subs.getReplacementTypes()) {
335+
if (replacementType->isUnsafe() &&
336+
fn(UnsafeUse::forReferenceToUnsafe(nullptr, false, replacementType, loc)))
337+
return true;
338+
}
334339

335340
// Check conformances.
336341
if (enumerateUnsafeUses(subs.getConformances(), loc, fn))
@@ -375,21 +380,73 @@ void swift::diagnoseUnsafeType(ASTContext &ctx, SourceLoc loc, Type type,
375380
if (!ctx.LangOpts.hasFeature(Feature::StrictMemorySafety))
376381
return;
377382

378-
if (!type->isUnsafe() && !type->getCanonicalType()->isUnsafe())
383+
if (!type->isUnsafe())
379384
return;
380385

381-
// Look for a specific @unsafe nominal type.
382-
Type specificType;
383-
type.findIf([&specificType](Type type) {
384-
if (auto typeDecl = type->getAnyNominal()) {
385-
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
386-
specificType = type;
387-
return false;
386+
// Look for a specific @unsafe nominal type along the way.
387+
class Walker : public TypeWalker {
388+
public:
389+
Type specificType;
390+
391+
Action walkToTypePre(Type type) override {
392+
if (specificType)
393+
return Action::Stop;
394+
395+
// If this refers to a nominal type that is @unsafe, store that.
396+
if (auto typeDecl = type->getAnyNominal()) {
397+
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
398+
specificType = type;
399+
return Action::Stop;
400+
}
401+
}
402+
403+
// Do not recurse into nominal types, because we do not want to visit
404+
// their "parent" types.
405+
if (isa<NominalOrBoundGenericNominalType>(type.getPointer()) ||
406+
isa<UnboundGenericType>(type.getPointer())) {
407+
// Recurse into the generic arguments. This operation is recursive,
408+
// because we also need to see the generic arguments of parent types.
409+
walkGenericArguments(type);
410+
411+
return Action::SkipNode;
412+
}
413+
414+
return Action::Continue;
415+
}
416+
417+
private:
418+
/// Recursively walk the generic arguments of this type and its parent
419+
/// types.
420+
void walkGenericArguments(Type type) {
421+
if (!type)
422+
return;
423+
424+
// Walk the generic arguments.
425+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
426+
for (auto genericArg : boundGeneric->getGenericArgs())
427+
genericArg.walk(*this);
388428
}
429+
430+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>())
431+
return walkGenericArguments(nominalOrBound->getParent());
432+
433+
if (auto unbound = type->getAs<UnboundGenericType>())
434+
return walkGenericArguments(unbound->getParent());
389435
}
436+
};
390437

391-
return false;
392-
});
438+
// Look for a canonical unsafe type.
439+
Walker walker;
440+
type->getCanonicalType().walk(walker);
441+
Type specificType = walker.specificType;
442+
443+
// Look for an unsafe type in the non-canonical type, which is a better answer
444+
// if we can find it.
445+
walker.specificType = Type();
446+
type.walk(walker);
447+
if (specificType && walker.specificType &&
448+
specificType->isEqual(walker.specificType))
449+
specificType = walker.specificType;
393450

394451
diagnose(specificType ? specificType : type);
395452
}

0 commit comments

Comments
 (0)