Skip to content

Commit 78608b3

Browse files
authored
Merge pull request #78585 from DougGregor/more-unsafe-effects-checking
Tighten up `unsafe` effects checking
2 parents a3173a0 + ff2ef7a commit 78608b3

12 files changed

+454
-108
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8093,6 +8093,9 @@ NOTE(note_reference_to_unsafe_typed_decl,none,
80938093
NOTE(note_reference_to_unsafe_through_typealias,none,
80948094
"reference to %kind0 whose underlying type involves unsafe type %1",
80958095
(const ValueDecl *, Type))
8096+
NOTE(note_reference_to_unsafe_type,none,
8097+
"reference to unsafe type %0",
8098+
(Type))
80968099
NOTE(note_reference_to_nonisolated_unsafe,none,
80978100
"reference to nonisolated(unsafe) %kind0 is unsafe in concurrently-executing code",
80988101
(const ValueDecl *))

include/swift/AST/UnsafeUse.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,33 @@ class UnsafeUse {
225225
}
226226
}
227227

228+
/// Replace the location, if possible.
229+
void replaceLocation(SourceLoc loc) {
230+
switch (getKind()) {
231+
case Override:
232+
case Witness:
233+
case PreconcurrencyImport:
234+
// Cannot replace location.
235+
return;
236+
237+
case UnsafeConformance:
238+
storage.conformance.location = loc.getOpaquePointerValue();
239+
break;
240+
241+
case TypeWitness:
242+
storage.typeWitness.location = loc.getOpaquePointerValue();
243+
break;
244+
245+
case UnownedUnsafe:
246+
case ExclusivityUnchecked:
247+
case NonisolatedUnsafe:
248+
case ReferenceToUnsafe:
249+
case ReferenceToUnsafeThroughTypealias:
250+
case CallToUnsafe:
251+
storage.entity.location = loc.getOpaquePointerValue();
252+
}
253+
}
254+
228255
/// Get the main declaration, when there is one.
229256
const Decl *getDecl() const {
230257
switch (getKind()) {

lib/AST/ConformanceLookupTable.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,6 @@ DeclContext *ConformanceLookupTable::ConformanceSource::getDeclContext() const {
5353
llvm_unreachable("Unhandled ConformanceEntryKind in switch.");
5454
}
5555

56-
bool ConformanceLookupTable::ConformanceSource::isUnsafeContext(DeclContext *dc) {
57-
if (auto enclosingNominal = dc->getSelfNominalTypeDecl())
58-
if (enclosingNominal->isUnsafe())
59-
return true;
60-
61-
if (auto ext = dyn_cast<ExtensionDecl>(dc))
62-
if (ext->getAttrs().hasAttribute<UnsafeAttr>())
63-
return true;
64-
65-
return false;
66-
}
67-
6856
ProtocolDecl *ConformanceLookupTable::ConformanceEntry::getProtocol() const {
6957
if (auto protocol = Conformance.dyn_cast<ProtocolDecl *>())
7058
return protocol;
@@ -166,13 +154,17 @@ namespace {
166154
/// The location of the "preconcurrency" attribute if present.
167155
const SourceLoc preconcurrencyLoc;
168156

157+
/// The location of the "unsafe" attribute if present.
158+
const SourceLoc unsafeLoc;
159+
169160
ConformanceConstructionInfo() { }
170161

171162
ConformanceConstructionInfo(ProtocolDecl *item, SourceLoc loc,
172163
SourceLoc uncheckedLoc,
173-
SourceLoc preconcurrencyLoc)
164+
SourceLoc preconcurrencyLoc,
165+
SourceLoc unsafeLoc)
174166
: Located(item, loc), uncheckedLoc(uncheckedLoc),
175-
preconcurrencyLoc(preconcurrencyLoc) {}
167+
preconcurrencyLoc(preconcurrencyLoc), unsafeLoc(unsafeLoc) {}
176168
};
177169
}
178170

@@ -228,7 +220,7 @@ void ConformanceLookupTable::forEachInStage(ConformanceStage stage,
228220
registerProtocolConformances(next, conformances);
229221
for (auto conf : conformances) {
230222
protocols.push_back(
231-
{conf->getProtocol(), SourceLoc(), SourceLoc(), SourceLoc()});
223+
{conf->getProtocol(), SourceLoc(), SourceLoc(), SourceLoc(), SourceLoc()});
232224
}
233225
} else if (next->getParentSourceFile() ||
234226
next->getParentModule()->isBuiltinModule()) {
@@ -238,7 +230,8 @@ void ConformanceLookupTable::forEachInStage(ConformanceStage stage,
238230
getDirectlyInheritedNominalTypeDecls(next, inverses, anyObject)) {
239231
if (auto proto = dyn_cast<ProtocolDecl>(found.Item))
240232
protocols.push_back(
241-
{proto, found.Loc, found.uncheckedLoc, found.preconcurrencyLoc});
233+
{proto, found.Loc, found.uncheckedLoc,
234+
found.preconcurrencyLoc, found.unsafeLoc});
242235
}
243236
}
244237

@@ -343,7 +336,8 @@ void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal,
343336
addProtocol(
344337
locAndProto.Item, locAndProto.Loc,
345338
source.withUncheckedLoc(locAndProto.uncheckedLoc)
346-
.withPreconcurrencyLoc(locAndProto.preconcurrencyLoc));
339+
.withPreconcurrencyLoc(locAndProto.preconcurrencyLoc)
340+
.withUnsafeLoc(locAndProto.unsafeLoc));
347341
});
348342
break;
349343

lib/AST/ConformanceLookupTable.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class ConformanceLookupTable : public ASTAllocated<ConformanceLookupTable> {
171171
options |= ProtocolConformanceFlags::Unchecked;
172172
if (getPreconcurrencyLoc().isValid())
173173
options |= ProtocolConformanceFlags::Preconcurrency;
174-
if (getUnsafeLoc().isValid() || isUnsafeContext(getDeclContext()))
174+
if (getUnsafeLoc().isValid())
175175
options |= ProtocolConformanceFlags::Unsafe;
176176
return options;
177177
}
@@ -264,10 +264,6 @@ class ConformanceLookupTable : public ASTAllocated<ConformanceLookupTable> {
264264
/// Get the declaration context that this conformance will be
265265
/// associated with.
266266
DeclContext *getDeclContext() const;
267-
268-
private:
269-
/// Whether this declaration context is @unsafe.
270-
static bool isUnsafeContext(DeclContext *dc);
271267
};
272268

273269
/// An entry in the conformance table.

lib/Sema/CSGen.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3663,6 +3663,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
36633663
ASTContext &ctx = cs.getASTContext();
36643664
bool isAsync = stmt->getAwaitLoc().isValid();
36653665
auto *sequenceExpr = stmt->getParsedSequence();
3666+
3667+
// If we have an unsafe expression for the sequence, lift it out of the
3668+
// sequence expression. We'll put it back after we've introduced the
3669+
// various calls.
3670+
UnsafeExpr *unsafeExpr = dyn_cast<UnsafeExpr>(sequenceExpr);
3671+
if (unsafeExpr) {
3672+
sequenceExpr = unsafeExpr->getSubExpr();
3673+
}
3674+
36663675
auto contextualLocator = cs.getConstraintLocator(
36673676
sequenceExpr, LocatorPathElt::ContextualType(CTP_ForEachSequence));
36683677
auto elementLocator = cs.getConstraintLocator(
@@ -3712,9 +3721,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
37123721
ctx, sequenceExpr, makeIterator->getName());
37133722
makeIteratorRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply());
37143723

3715-
auto *makeIteratorCall =
3724+
Expr *makeIteratorCall =
37163725
CallExpr::createImplicitEmpty(ctx, makeIteratorRef);
37173726

3727+
// Swap in the 'unsafe' expression.
3728+
if (unsafeExpr) {
3729+
unsafeExpr->setSubExpr(makeIteratorCall);
3730+
makeIteratorCall = unsafeExpr;
3731+
}
3732+
37183733
Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar);
37193734
auto *PB = PatternBindingDecl::createImplicit(
37203735
ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc);

0 commit comments

Comments
 (0)