Skip to content

Tighten up unsafe effects checking #78585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -8093,6 +8093,9 @@ NOTE(note_reference_to_unsafe_typed_decl,none,
NOTE(note_reference_to_unsafe_through_typealias,none,
"reference to %kind0 whose underlying type involves unsafe type %1",
(const ValueDecl *, Type))
NOTE(note_reference_to_unsafe_type,none,
"reference to unsafe type %0",
(Type))
NOTE(note_reference_to_nonisolated_unsafe,none,
"reference to nonisolated(unsafe) %kind0 is unsafe in concurrently-executing code",
(const ValueDecl *))
Expand Down
27 changes: 27 additions & 0 deletions include/swift/AST/UnsafeUse.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,33 @@ class UnsafeUse {
}
}

/// Replace the location, if possible.
void replaceLocation(SourceLoc loc) {
switch (getKind()) {
case Override:
case Witness:
case PreconcurrencyImport:
// Cannot replace location.
return;

case UnsafeConformance:
storage.conformance.location = loc.getOpaquePointerValue();
break;

case TypeWitness:
storage.typeWitness.location = loc.getOpaquePointerValue();
break;

case UnownedUnsafe:
case ExclusivityUnchecked:
case NonisolatedUnsafe:
case ReferenceToUnsafe:
case ReferenceToUnsafeThroughTypealias:
case CallToUnsafe:
storage.entity.location = loc.getOpaquePointerValue();
}
}

/// Get the main declaration, when there is one.
const Decl *getDecl() const {
switch (getKind()) {
Expand Down
28 changes: 11 additions & 17 deletions lib/AST/ConformanceLookupTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,6 @@ DeclContext *ConformanceLookupTable::ConformanceSource::getDeclContext() const {
llvm_unreachable("Unhandled ConformanceEntryKind in switch.");
}

bool ConformanceLookupTable::ConformanceSource::isUnsafeContext(DeclContext *dc) {
if (auto enclosingNominal = dc->getSelfNominalTypeDecl())
if (enclosingNominal->isUnsafe())
return true;

if (auto ext = dyn_cast<ExtensionDecl>(dc))
if (ext->getAttrs().hasAttribute<UnsafeAttr>())
return true;

return false;
}

ProtocolDecl *ConformanceLookupTable::ConformanceEntry::getProtocol() const {
if (auto protocol = Conformance.dyn_cast<ProtocolDecl *>())
return protocol;
Expand Down Expand Up @@ -166,13 +154,17 @@ namespace {
/// The location of the "preconcurrency" attribute if present.
const SourceLoc preconcurrencyLoc;

/// The location of the "unsafe" attribute if present.
const SourceLoc unsafeLoc;

ConformanceConstructionInfo() { }

ConformanceConstructionInfo(ProtocolDecl *item, SourceLoc loc,
SourceLoc uncheckedLoc,
SourceLoc preconcurrencyLoc)
SourceLoc preconcurrencyLoc,
SourceLoc unsafeLoc)
: Located(item, loc), uncheckedLoc(uncheckedLoc),
preconcurrencyLoc(preconcurrencyLoc) {}
preconcurrencyLoc(preconcurrencyLoc), unsafeLoc(unsafeLoc) {}
};
}

Expand Down Expand Up @@ -228,7 +220,7 @@ void ConformanceLookupTable::forEachInStage(ConformanceStage stage,
registerProtocolConformances(next, conformances);
for (auto conf : conformances) {
protocols.push_back(
{conf->getProtocol(), SourceLoc(), SourceLoc(), SourceLoc()});
{conf->getProtocol(), SourceLoc(), SourceLoc(), SourceLoc(), SourceLoc()});
}
} else if (next->getParentSourceFile() ||
next->getParentModule()->isBuiltinModule()) {
Expand All @@ -238,7 +230,8 @@ void ConformanceLookupTable::forEachInStage(ConformanceStage stage,
getDirectlyInheritedNominalTypeDecls(next, inverses, anyObject)) {
if (auto proto = dyn_cast<ProtocolDecl>(found.Item))
protocols.push_back(
{proto, found.Loc, found.uncheckedLoc, found.preconcurrencyLoc});
{proto, found.Loc, found.uncheckedLoc,
found.preconcurrencyLoc, found.unsafeLoc});
}
}

Expand Down Expand Up @@ -343,7 +336,8 @@ void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal,
addProtocol(
locAndProto.Item, locAndProto.Loc,
source.withUncheckedLoc(locAndProto.uncheckedLoc)
.withPreconcurrencyLoc(locAndProto.preconcurrencyLoc));
.withPreconcurrencyLoc(locAndProto.preconcurrencyLoc)
.withUnsafeLoc(locAndProto.unsafeLoc));
});
break;

Expand Down
6 changes: 1 addition & 5 deletions lib/AST/ConformanceLookupTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class ConformanceLookupTable : public ASTAllocated<ConformanceLookupTable> {
options |= ProtocolConformanceFlags::Unchecked;
if (getPreconcurrencyLoc().isValid())
options |= ProtocolConformanceFlags::Preconcurrency;
if (getUnsafeLoc().isValid() || isUnsafeContext(getDeclContext()))
if (getUnsafeLoc().isValid())
options |= ProtocolConformanceFlags::Unsafe;
return options;
}
Expand Down Expand Up @@ -264,10 +264,6 @@ class ConformanceLookupTable : public ASTAllocated<ConformanceLookupTable> {
/// Get the declaration context that this conformance will be
/// associated with.
DeclContext *getDeclContext() const;

private:
/// Whether this declaration context is @unsafe.
static bool isUnsafeContext(DeclContext *dc);
};

/// An entry in the conformance table.
Expand Down
17 changes: 16 additions & 1 deletion lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3663,6 +3663,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
ASTContext &ctx = cs.getASTContext();
bool isAsync = stmt->getAwaitLoc().isValid();
auto *sequenceExpr = stmt->getParsedSequence();

// If we have an unsafe expression for the sequence, lift it out of the
// sequence expression. We'll put it back after we've introduced the
// various calls.
UnsafeExpr *unsafeExpr = dyn_cast<UnsafeExpr>(sequenceExpr);
if (unsafeExpr) {
sequenceExpr = unsafeExpr->getSubExpr();
}

auto contextualLocator = cs.getConstraintLocator(
sequenceExpr, LocatorPathElt::ContextualType(CTP_ForEachSequence));
auto elementLocator = cs.getConstraintLocator(
Expand Down Expand Up @@ -3712,9 +3721,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
ctx, sequenceExpr, makeIterator->getName());
makeIteratorRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply());

auto *makeIteratorCall =
Expr *makeIteratorCall =
CallExpr::createImplicitEmpty(ctx, makeIteratorRef);

// Swap in the 'unsafe' expression.
if (unsafeExpr) {
unsafeExpr->setSubExpr(makeIteratorCall);
makeIteratorCall = unsafeExpr;
}

Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar);
auto *PB = PatternBindingDecl::createImplicit(
ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc);
Expand Down
Loading