Skip to content

Commit 54e4400

Browse files
authored
Merge pull request #80933 from DougGregor/safe-nested-in-unsafe-fixes
[Strict memory safety] Improve handling of safe types nested within unsafe ones
2 parents 3718a15 + 1c9875e commit 54e4400

33 files changed

+579
-388
lines changed

include/swift/AST/Types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ class RecursiveTypeProperties {
318318
Bits &= ~HasDependentMember;
319319
}
320320

321+
/// Remove the IsUnsafe property from this set.
322+
void removeIsUnsafe() {
323+
Bits &= ~IsUnsafe;
324+
}
325+
321326
/// Test for a particular property in this set.
322327
bool operator&(Property prop) const {
323328
return Bits & prop;

lib/AST/ASTContext.cpp

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3543,6 +3543,12 @@ TypeAliasType *TypeAliasType::get(TypeAliasDecl *typealias, Type parent,
35433543
auto &ctx = underlying->getASTContext();
35443544
auto arena = getArena(properties);
35453545

3546+
// Typealiases can't meaningfully be unsafe; it's the underlying type that
3547+
// matters.
3548+
properties.removeIsUnsafe();
3549+
if (underlying->isUnsafe())
3550+
properties |= RecursiveTypeProperties::IsUnsafe;
3551+
35463552
// Profile the type.
35473553
llvm::FoldingSetNodeID id;
35483554
TypeAliasType::Profile(id, typealias, parent, genericArgs, underlying);
@@ -4190,6 +4196,54 @@ void UnboundGenericType::Profile(llvm::FoldingSetNodeID &ID,
41904196
ID.AddPointer(Parent.getPointer());
41914197
}
41924198

4199+
/// The safety of a parent type does not have an impact on a nested type within
4200+
/// it. This produces the recursive properties of a given type that should
4201+
/// be propagated to a nested type, which won't include any "IsUnsafe" bit
4202+
/// determined based on the declaration itself.
4203+
static RecursiveTypeProperties getRecursivePropertiesAsParent(Type type) {
4204+
if (!type)
4205+
return RecursiveTypeProperties();
4206+
4207+
// We only need to do anything interesting at all for unsafe types.
4208+
auto properties = type->getRecursiveProperties();
4209+
if (!properties.isUnsafe())
4210+
return properties;
4211+
4212+
if (auto nominal = type->getAnyNominal()) {
4213+
// If the nominal wasn't itself unsafe, then we got the unsafety from
4214+
// something else (e.g., a generic argument), so it won't change.
4215+
if (nominal->getExplicitSafety() != ExplicitSafety::Unsafe)
4216+
return properties;
4217+
}
4218+
4219+
// Drop the "unsafe" bit. We have to recompute it without considering the
4220+
// enclosing nominal type.
4221+
properties.removeIsUnsafe();
4222+
4223+
// Check generic arguments of parent types.
4224+
while (type) {
4225+
// Merge from the generic arguments.
4226+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
4227+
for (auto genericArg : boundGeneric->getGenericArgs())
4228+
properties |= genericArg->getRecursiveProperties();
4229+
}
4230+
4231+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>()) {
4232+
type = nominalOrBound->getParent();
4233+
continue;
4234+
}
4235+
4236+
if (auto unbound = type->getAs<UnboundGenericType>()) {
4237+
type = unbound->getParent();
4238+
continue;
4239+
}
4240+
4241+
break;
4242+
};
4243+
4244+
return properties;
4245+
}
4246+
41934247
UnboundGenericType *UnboundGenericType::
41944248
get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41954249
llvm::FoldingSetNodeID ID;
@@ -4198,7 +4252,7 @@ get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41984252
RecursiveTypeProperties properties;
41994253
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42004254
properties |= RecursiveTypeProperties::IsUnsafe;
4201-
if (Parent) properties |= Parent->getRecursiveProperties();
4255+
properties |= getRecursivePropertiesAsParent(Parent);
42024256

42034257
auto arena = getArena(properties);
42044258

@@ -4252,7 +4306,7 @@ BoundGenericType *BoundGenericType::get(NominalTypeDecl *TheDecl,
42524306
RecursiveTypeProperties properties;
42534307
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42544308
properties |= RecursiveTypeProperties::IsUnsafe;
4255-
if (Parent) properties |= Parent->getRecursiveProperties();
4309+
properties |= getRecursivePropertiesAsParent(Parent);
42564310
for (Type Arg : GenericArgs) {
42574311
properties |= Arg->getRecursiveProperties();
42584312
}
@@ -4335,7 +4389,7 @@ EnumType *EnumType::get(EnumDecl *D, Type Parent, const ASTContext &C) {
43354389
RecursiveTypeProperties properties;
43364390
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43374391
properties |= RecursiveTypeProperties::IsUnsafe;
4338-
if (Parent) properties |= Parent->getRecursiveProperties();
4392+
properties |= getRecursivePropertiesAsParent(Parent);
43394393
auto arena = getArena(properties);
43404394

43414395
auto *&known = C.getImpl().getArena(arena).EnumTypes[{D, Parent}];
@@ -4353,7 +4407,7 @@ StructType *StructType::get(StructDecl *D, Type Parent, const ASTContext &C) {
43534407
RecursiveTypeProperties properties;
43544408
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43554409
properties |= RecursiveTypeProperties::IsUnsafe;
4356-
if (Parent) properties |= Parent->getRecursiveProperties();
4410+
properties |= getRecursivePropertiesAsParent(Parent);
43574411
auto arena = getArena(properties);
43584412

43594413
auto *&known = C.getImpl().getArena(arena).StructTypes[{D, Parent}];
@@ -4371,7 +4425,7 @@ ClassType *ClassType::get(ClassDecl *D, Type Parent, const ASTContext &C) {
43714425
RecursiveTypeProperties properties;
43724426
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43734427
properties |= RecursiveTypeProperties::IsUnsafe;
4374-
if (Parent) properties |= Parent->getRecursiveProperties();
4428+
properties |= getRecursivePropertiesAsParent(Parent);
43754429
auto arena = getArena(properties);
43764430

43774431
auto *&known = C.getImpl().getArena(arena).ClassTypes[{D, Parent}];
@@ -5538,7 +5592,7 @@ ProtocolType *ProtocolType::get(ProtocolDecl *D, Type Parent,
55385592
RecursiveTypeProperties properties;
55395593
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
55405594
properties |= RecursiveTypeProperties::IsUnsafe;
5541-
if (Parent) properties |= Parent->getRecursiveProperties();
5595+
properties |= getRecursivePropertiesAsParent(Parent);
55425596
auto arena = getArena(properties);
55435597

55445598
auto *&known = C.getImpl().getArena(arena).ProtocolTypes[{D, Parent}];

lib/AST/Decl.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,12 +1216,14 @@ ExplicitSafety Decl::getExplicitSafety() const {
12161216
ExplicitSafety::Unspecified);
12171217
}
12181218

1219-
// Inference: Check the enclosing context.
1220-
if (auto enclosingDC = getDeclContext()) {
1221-
// Is this an extension with @safe or @unsafe on it?
1222-
if (auto ext = dyn_cast<ExtensionDecl>(enclosingDC)) {
1223-
if (auto extSafety = getExplicitSafetyFromAttrs(ext))
1224-
return *extSafety;
1219+
// Inference: Check the enclosing context, unless this is a type.
1220+
if (!isa<TypeDecl>(this)) {
1221+
if (auto enclosingDC = getDeclContext()) {
1222+
// Is this an extension with @safe or @unsafe on it?
1223+
if (auto ext = dyn_cast<ExtensionDecl>(enclosingDC)) {
1224+
if (auto extSafety = getExplicitSafetyFromAttrs(ext))
1225+
return *extSafety;
1226+
}
12251227
}
12261228
}
12271229

lib/Sema/TypeCheckEffects.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ class EffectsHandlingWalker : public ASTWalker {
670670
recurse = asImpl().checkForEach(forEach);
671671
} else if (auto labeled = dyn_cast<LabeledConditionalStmt>(S)) {
672672
asImpl().noteLabeledConditionalStmt(labeled);
673+
} else if (auto defer = dyn_cast<DeferStmt>(S)) {
674+
recurse = asImpl().checkDefer(defer);
673675
}
674676

675677
if (!recurse)
@@ -2106,6 +2108,10 @@ class ApplyClassifier {
21062108
return ShouldRecurse;
21072109
}
21082110

2111+
ShouldRecurse_t checkDefer(DeferStmt *S) {
2112+
return ShouldNotRecurse;
2113+
}
2114+
21092115
ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
21102116
return ShouldRecurse;
21112117
}
@@ -2251,6 +2257,10 @@ class ApplyClassifier {
22512257
return ShouldRecurse;
22522258
}
22532259

2260+
ShouldRecurse_t checkDefer(DeferStmt *S) {
2261+
return ShouldNotRecurse;
2262+
}
2263+
22542264
ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
22552265
return ShouldRecurse;
22562266
}
@@ -2350,6 +2360,10 @@ class ApplyClassifier {
23502360
return ShouldNotRecurse;
23512361
}
23522362

2363+
ShouldRecurse_t checkDefer(DeferStmt *S) {
2364+
return ShouldNotRecurse;
2365+
}
2366+
23532367
ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
23542368
return ShouldRecurse;
23552369
}
@@ -4394,6 +4408,17 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
43944408
return ShouldRecurse;
43954409
}
43964410

4411+
ShouldRecurse_t checkDefer(DeferStmt *S) {
4412+
// Pretend we're in an 'unsafe'.
4413+
ContextScope scope(*this, std::nullopt);
4414+
scope.enterUnsafe(S->getDeferLoc());
4415+
4416+
// Walk the call expression. We don't care about the rest.
4417+
S->getCallExpr()->walk(*this);
4418+
4419+
return ShouldNotRecurse;
4420+
}
4421+
43974422
void diagnoseRedundantTry(AnyTryExpr *E) const {
43984423
if (auto *SVE = SingleValueStmtExpr::tryDigOutSingleValueStmtExpr(E)) {
43994424
// For an if/switch expression, produce a tailored warning.

lib/Sema/TypeCheckUnsafe.cpp

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,12 @@ bool swift::enumerateUnsafeUses(ArrayRef<ProtocolConformanceRef> conformances,
327327
bool swift::enumerateUnsafeUses(SubstitutionMap subs,
328328
SourceLoc loc,
329329
llvm::function_ref<bool(UnsafeUse)> fn) {
330-
// FIXME: Check replacement types?
330+
// Replacement types.
331+
for (auto replacementType : subs.getReplacementTypes()) {
332+
if (replacementType->isUnsafe() &&
333+
fn(UnsafeUse::forReferenceToUnsafe(nullptr, false, replacementType, loc)))
334+
return true;
335+
}
331336

332337
// Check conformances.
333338
if (enumerateUnsafeUses(subs.getConformances(), loc, fn))
@@ -372,21 +377,73 @@ void swift::diagnoseUnsafeType(ASTContext &ctx, SourceLoc loc, Type type,
372377
if (!ctx.LangOpts.hasFeature(Feature::StrictMemorySafety))
373378
return;
374379

375-
if (!type->isUnsafe() && !type->getCanonicalType()->isUnsafe())
380+
if (!type->isUnsafe())
376381
return;
377382

378-
// Look for a specific @unsafe nominal type.
379-
Type specificType;
380-
type.findIf([&specificType](Type type) {
381-
if (auto typeDecl = type->getAnyNominal()) {
382-
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
383-
specificType = type;
384-
return false;
383+
// Look for a specific @unsafe nominal type along the way.
384+
class Walker : public TypeWalker {
385+
public:
386+
Type specificType;
387+
388+
Action walkToTypePre(Type type) override {
389+
if (specificType)
390+
return Action::Stop;
391+
392+
// If this refers to a nominal type that is @unsafe, store that.
393+
if (auto typeDecl = type->getAnyNominal()) {
394+
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
395+
specificType = type;
396+
return Action::Stop;
397+
}
398+
}
399+
400+
// Do not recurse into nominal types, because we do not want to visit
401+
// their "parent" types.
402+
if (isa<NominalOrBoundGenericNominalType>(type.getPointer()) ||
403+
isa<UnboundGenericType>(type.getPointer())) {
404+
// Recurse into the generic arguments. This operation is recursive,
405+
// because we also need to see the generic arguments of parent types.
406+
walkGenericArguments(type);
407+
408+
return Action::SkipNode;
409+
}
410+
411+
return Action::Continue;
412+
}
413+
414+
private:
415+
/// Recursively walk the generic arguments of this type and its parent
416+
/// types.
417+
void walkGenericArguments(Type type) {
418+
if (!type)
419+
return;
420+
421+
// Walk the generic arguments.
422+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
423+
for (auto genericArg : boundGeneric->getGenericArgs())
424+
genericArg.walk(*this);
385425
}
426+
427+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>())
428+
return walkGenericArguments(nominalOrBound->getParent());
429+
430+
if (auto unbound = type->getAs<UnboundGenericType>())
431+
return walkGenericArguments(unbound->getParent());
386432
}
433+
};
387434

388-
return false;
389-
});
435+
// Look for a canonical unsafe type.
436+
Walker walker;
437+
type->getCanonicalType().walk(walker);
438+
Type specificType = walker.specificType;
439+
440+
// Look for an unsafe type in the non-canonical type, which is a better answer
441+
// if we can find it.
442+
walker.specificType = Type();
443+
type.walk(walker);
444+
if (specificType && walker.specificType &&
445+
specificType->isEqual(walker.specificType))
446+
specificType = walker.specificType;
390447

391448
diagnose(specificType ? specificType : type);
392449
}

0 commit comments

Comments
 (0)