Skip to content

Commit 50801f9

Browse files
committed
[SE-0458] Implement "unsafe" effect for the for-in loop
Memory unsafety in the iteration part of the for-in loop (i.e., the part that works on the iterator) can be covered by the "unsafe" effect on the for..in loop, before the pattern.
1 parent 53b3460 commit 50801f9

File tree

13 files changed

+308
-109
lines changed

13 files changed

+308
-109
lines changed

include/swift/AST/ASTBridging.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2289,11 +2289,12 @@ BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
22892289
BridgedDeclContext cDC);
22902290

22912291
SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:"
2292-
"pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
2292+
"unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
22932293
BridgedForEachStmt BridgedForEachStmt_createParsed(
22942294
BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo,
22952295
BridgedSourceLoc cForLoc, BridgedSourceLoc cTryLoc,
2296-
BridgedSourceLoc cAwaitLoc, BridgedPattern cPat, BridgedSourceLoc cInLoc,
2296+
BridgedSourceLoc cAwaitLoc, BridgedSourceLoc cUnsafeLoc,
2297+
BridgedPattern cPat, BridgedSourceLoc cInLoc,
22972298
BridgedExpr cSequence, BridgedSourceLoc cWhereLoc,
22982299
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody);
22992300

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8212,8 +8212,12 @@ GROUPED_WARNING(preconcurrency_import_unsafe,Unsafe,none,
82128212
"introduce data races", ())
82138213
GROUPED_WARNING(unsafe_without_unsafe,Unsafe,none,
82148214
"expression uses unsafe constructs but is not marked with 'unsafe'", ())
8215+
GROUPED_WARNING(for_unsafe_without_unsafe,Unsafe,none,
8216+
"for-in loop uses unsafe constructs but is not marked with 'unsafe'", ())
82158217
WARNING(no_unsafe_in_unsafe,none,
82168218
"no unsafe operations occur within 'unsafe' expression", ())
8219+
WARNING(no_unsafe_in_unsafe_for,none,
8220+
"no unsafe operations occur within 'unsafe' for-in loop", ())
82178221
NOTE(make_subclass_unsafe,none,
82188222
"make class %0 @unsafe to allow unsafe overrides of safe superclass methods",
82198223
(DeclName))

include/swift/AST/Stmt.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,7 @@ class ForEachStmt : public LabeledStmt {
995995
SourceLoc ForLoc;
996996
SourceLoc TryLoc;
997997
SourceLoc AwaitLoc;
998+
SourceLoc UnsafeLoc;
998999
Pattern *Pat;
9991000
SourceLoc InLoc;
10001001
Expr *Sequence;
@@ -1012,13 +1013,14 @@ class ForEachStmt : public LabeledStmt {
10121013

10131014
public:
10141015
ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, SourceLoc TryLoc,
1015-
SourceLoc AwaitLoc, Pattern *Pat, SourceLoc InLoc, Expr *Sequence,
1016+
SourceLoc AwaitLoc, SourceLoc UnsafeLoc, Pattern *Pat,
1017+
SourceLoc InLoc, Expr *Sequence,
10161018
SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body,
10171019
std::optional<bool> implicit = std::nullopt)
10181020
: LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc),
10191021
LabelInfo),
1020-
ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), Pat(nullptr),
1021-
InLoc(InLoc), Sequence(Sequence), WhereLoc(WhereLoc),
1022+
ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), UnsafeLoc(UnsafeLoc),
1023+
Pat(nullptr), InLoc(InLoc), Sequence(Sequence), WhereLoc(WhereLoc),
10221024
WhereExpr(WhereExpr), Body(Body) {
10231025
setPattern(Pat);
10241026
}
@@ -1056,6 +1058,7 @@ class ForEachStmt : public LabeledStmt {
10561058

10571059
SourceLoc getAwaitLoc() const { return AwaitLoc; }
10581060
SourceLoc getTryLoc() const { return TryLoc; }
1061+
SourceLoc getUnsafeLoc() const { return UnsafeLoc; }
10591062

10601063
/// getPattern - Retrieve the pattern describing the iteration variables.
10611064
/// These variables will only be visible within the body of the loop.

include/swift/Parse/IDEInspectionCallbacks.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ class CodeCompletionCallbacks {
287287
virtual void completeStmtLabel(StmtKind ParentKind) {};
288288

289289
virtual
290-
void completeForEachPatternBeginning(bool hasTry, bool hasAwait) {};
290+
void completeForEachPatternBeginning(
291+
bool hasTry, bool hasAwait, bool hasUnsafe) {};
291292

292293
virtual void completeTypeAttrBeginning() {};
293294

lib/AST/Bridging/StmtBridging.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,15 @@ BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
191191
BridgedForEachStmt BridgedForEachStmt_createParsed(
192192
BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo,
193193
BridgedSourceLoc cForLoc, BridgedSourceLoc cTryLoc,
194-
BridgedSourceLoc cAwaitLoc, BridgedPattern cPat, BridgedSourceLoc cInLoc,
194+
BridgedSourceLoc cAwaitLoc, BridgedSourceLoc cUnsafeLoc,
195+
BridgedPattern cPat, BridgedSourceLoc cInLoc,
195196
BridgedExpr cSequence, BridgedSourceLoc cWhereLoc,
196197
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody) {
197198
return new (cContext.unbridged()) ForEachStmt(
198199
cLabelInfo.unbridged(), cForLoc.unbridged(), cTryLoc.unbridged(),
199-
cAwaitLoc.unbridged(), cPat.unbridged(), cInLoc.unbridged(),
200-
cSequence.unbridged(), cWhereLoc.unbridged(), cWhereExpr.unbridged(),
201-
cBody.unbridged());
200+
cAwaitLoc.unbridged(), cUnsafeLoc.unbridged(), cPat.unbridged(),
201+
cInLoc.unbridged(), cSequence.unbridged(), cWhereLoc.unbridged(),
202+
cWhereExpr.unbridged(), cBody.unbridged());
202203
}
203204

204205
BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext,

lib/ASTGen/Sources/ASTGen/Stmts.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ extension ASTGenVisitor {
323323
forLoc: self.generateSourceLoc(node.forKeyword),
324324
tryLoc: self.generateSourceLoc(node.tryKeyword),
325325
awaitLoc: self.generateSourceLoc(node.awaitKeyword),
326+
unsafeLoc: self.generateSourceLoc(node.unsafeKeyword),
326327
// NOTE: The pattern can be either a refutable pattern after `case` or a
327328
// normal binding pattern. ASTGen doesn't care because it should be handled
328329
// by the parser.

lib/IDE/CodeCompletion.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ class CodeCompletionCallbacksImpl : public CodeCompletionCallbacks,
302302
void completeGenericRequirement() override;
303303
void completeAfterIfStmtElse() override;
304304
void completeStmtLabel(StmtKind ParentKind) override;
305-
void completeForEachPatternBeginning(bool hasTry, bool hasAwait) override;
305+
void completeForEachPatternBeginning(
306+
bool hasTry, bool hasAwait, bool hasUnsafe) override;
306307
void completeTypeAttrBeginning() override;
307308
void completeTypeAttrInheritanceBeginning() override;
308309
void completeOptionalBinding() override;
@@ -636,14 +637,16 @@ void CodeCompletionCallbacksImpl::completeStmtLabel(StmtKind ParentKind) {
636637
}
637638

638639
void CodeCompletionCallbacksImpl::completeForEachPatternBeginning(
639-
bool hasTry, bool hasAwait) {
640+
bool hasTry, bool hasAwait, bool hasUnsafe) {
640641
CurDeclContext = P.CurDeclContext;
641642
Kind = CompletionKind::ForEachPatternBeginning;
642643
ParsedKeywords.clear();
643644
if (hasTry)
644645
ParsedKeywords.emplace_back("try");
645646
if (hasAwait)
646647
ParsedKeywords.emplace_back("await");
648+
if (hasUnsafe)
649+
ParsedKeywords.emplace_back("unsafe");
647650
}
648651

649652
void CodeCompletionCallbacksImpl::completeOptionalBinding() {

lib/Parse/ParseStmt.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,6 +2395,7 @@ ParserResult<Stmt> Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) {
23952395
auto StartOfControl = Tok.getLoc();
23962396
SourceLoc AwaitLoc;
23972397
SourceLoc TryLoc;
2398+
SourceLoc UnsafeLoc;
23982399

23992400
if (Tok.isContextualKeyword("await")) {
24002401
AwaitLoc = consumeToken();
@@ -2405,10 +2406,15 @@ ParserResult<Stmt> Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) {
24052406
}
24062407
}
24072408

2409+
if (Context.LangOpts.hasFeature(Feature::WarnUnsafe) &&
2410+
Tok.isContextualKeyword("unsafe")) {
2411+
UnsafeLoc = consumeToken();
2412+
}
2413+
24082414
if (Tok.is(tok::code_complete)) {
24092415
if (CodeCompletionCallbacks) {
24102416
CodeCompletionCallbacks->completeForEachPatternBeginning(
2411-
TryLoc.isValid(), AwaitLoc.isValid());
2417+
TryLoc.isValid(), AwaitLoc.isValid(), UnsafeLoc.isValid());
24122418
}
24132419
consumeToken(tok::code_complete);
24142420
// Since 'completeForeachPatternBeginning' is a keyword only completion,
@@ -2522,7 +2528,8 @@ ParserResult<Stmt> Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) {
25222528

25232529
return makeParserResult(
25242530
Status,
2525-
new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, pattern.get(), InLoc,
2531+
new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, UnsafeLoc,
2532+
pattern.get(), InLoc,
25262533
Container.get(), WhereLoc, Where.getPtrOrNull(),
25272534
Body.get()));
25282535
}

lib/Sema/BuilderTransform.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ class ResultBuilderTransform
726726
auto *newForEach = new (ctx)
727727
ForEachStmt(forEachStmt->getLabelInfo(), forEachStmt->getForLoc(),
728728
forEachStmt->getTryLoc(), forEachStmt->getAwaitLoc(),
729+
forEachStmt->getUnsafeLoc(),
729730
forEachStmt->getPattern(), forEachStmt->getInLoc(),
730731
forEachStmt->getParsedSequence(),
731732
forEachStmt->getWhereLoc(), forEachStmt->getWhere(),

lib/Sema/CSGen.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4632,10 +4632,11 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
46324632
AwaitExpr::createImplicit(ctx, nextCall->getLoc(), nextCall);
46334633
}
46344634

4635-
// Wrap the 'next' call in 'unsafe', if there is one.
4636-
if (unsafeExpr) {
4637-
nextCall = new (ctx) UnsafeExpr(unsafeExpr->getLoc(), nextCall, Type(),
4638-
/*implicit=*/true);
4635+
// Wrap the 'next' call in 'unsafe', if the for..in loop has that
4636+
// effect.
4637+
if (stmt->getUnsafeLoc().isValid()) {
4638+
nextCall = new (ctx) UnsafeExpr(
4639+
stmt->getUnsafeLoc(), nextCall, Type(), /*implicit=*/true);
46394640
}
46404641

46414642
// The iterator type must conform to IteratorProtocol.

0 commit comments

Comments
 (0)