Skip to content

Commit ff04fbd

Browse files
committed
Fix name lookup and parser so that where clauses decl refs in case statements correctly bind to the current pattern instead of always to the first pattern. Thus the hacky var decl juggling in where clauses in SILGen can be deleted.
1 parent 2a0c44e commit ff04fbd

File tree

4 files changed

+68
-117
lines changed

4 files changed

+68
-117
lines changed

lib/AST/NameLookupImpl.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,21 +214,20 @@ class FindLocalVal : public StmtVisitor<FindLocalVal> {
214214
return;
215215
// Pattern names aren't visible in the patterns themselves,
216216
// just in the body or in where guards.
217-
auto body = S->getBody();
218217
bool inPatterns = isReferencePointInRange(S->getLabelItemsRange());
219218
auto items = S->getCaseLabelItems();
220219
if (inPatterns) {
221220
for (const auto &CLI : items) {
222221
auto guard = CLI.getGuardExpr();
223222
if (guard && isReferencePointInRange(guard->getSourceRange())) {
224-
inPatterns = false;
223+
checkPattern(CLI.getPattern(), DeclVisibilityKind::LocalVariable);
225224
break;
226225
}
227226
}
228227
}
229228
if (!inPatterns && !items.empty())
230229
checkPattern(items[0].getPattern(), DeclVisibilityKind::LocalVariable);
231-
visit(body);
230+
visit(S->getBody());
232231
}
233232

234233
void visitDoCatchStmt(DoCatchStmt *S) {

lib/Parse/ParseStmt.cpp

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,47 @@ namespace {
968968
};
969969
} // unnamed namespace
970970

971+
static void parseWhereGuard(Parser &P, GuardedPattern &result,
972+
ParserStatus &status,
973+
GuardedPatternContext parsingContext,
974+
bool isExprBasic) {
975+
if (P.Tok.is(tok::kw_where)) {
976+
SyntaxParsingContext WhereClauseCtxt(P.SyntaxContext,
977+
SyntaxKind::WhereClause);
978+
result.WhereLoc = P.consumeToken(tok::kw_where);
979+
SourceLoc startOfGuard = P.Tok.getLoc();
980+
981+
auto diagKind = [=]() -> Diag<> {
982+
switch (parsingContext) {
983+
case GuardedPatternContext::Case:
984+
return diag::expected_case_where_expr;
985+
case GuardedPatternContext::Catch:
986+
return diag::expected_catch_where_expr;
987+
}
988+
llvm_unreachable("bad context");
989+
}();
990+
ParserResult<Expr> guardResult = P.parseExprImpl(diagKind, isExprBasic);
991+
status |= guardResult;
992+
993+
// Use the parsed guard expression if possible.
994+
if (guardResult.isNonNull()) {
995+
result.Guard = guardResult.get();
996+
997+
// Otherwise, fake up an ErrorExpr.
998+
} else {
999+
// If we didn't consume any tokens failing to parse the
1000+
// expression, don't put in the source range of the ErrorExpr.
1001+
SourceRange errorRange;
1002+
if (startOfGuard == P.Tok.getLoc()) {
1003+
errorRange = result.WhereLoc;
1004+
} else {
1005+
errorRange = SourceRange(startOfGuard, P.PreviousLoc);
1006+
}
1007+
result.Guard = new (P.Context) ErrorExpr(errorRange);
1008+
}
1009+
}
1010+
}
1011+
9711012
/// Parse a pattern-matching clause for a case or catch statement,
9721013
/// including the guard expression:
9731014
///
@@ -1045,7 +1086,6 @@ static void parseGuardedPattern(Parser &P, GuardedPattern &result,
10451086
patternResult = makeParserResult(varPattern);
10461087
}
10471088

1048-
10491089
// Okay, if the special code-completion didn't kick in, parse a
10501090
// matching pattern.
10511091
if (patternResult.isNull()) {
@@ -1074,26 +1114,28 @@ static void parseGuardedPattern(Parser &P, GuardedPattern &result,
10741114
if (VD->hasName()) P.addToScope(VD);
10751115
boundDecls.push_back(VD);
10761116
});
1117+
1118+
// Now that we have them, mark them as being initialized without a PBD.
1119+
for (auto VD : boundDecls)
1120+
VD->setHasNonPatternBindingInit();
1121+
1122+
// Parse the optional 'where' guard.
1123+
parseWhereGuard(P, result, status, parsingContext, isExprBasic);
10771124
} else {
10781125
// If boundDecls already contains variables, then we must match the
10791126
// same number and same names in this pattern as were declared in a
10801127
// previous pattern (and later we will make sure they have the same
10811128
// types).
1129+
Scope guardScope(&P, ScopeKind::CaseVars);
10821130
SmallVector<VarDecl*, 4> repeatedDecls;
10831131
patternResult.get()->forEachVariable([&](VarDecl *VD) {
10841132
if (!VD->hasName())
10851133
return;
10861134

1087-
for (auto repeat : repeatedDecls)
1088-
if (repeat->getName() == VD->getName())
1089-
P.addToScope(VD); // will diagnose a duplicate declaration
1090-
10911135
bool found = false;
10921136
for (auto previous : boundDecls) {
10931137
if (previous->hasName() && previous->getName() == VD->getName()) {
10941138
found = true;
1095-
// Use the same local discriminator.
1096-
VD->setLocalDiscriminator(previous->getLocalDiscriminator());
10971139
break;
10981140
}
10991141
}
@@ -1103,6 +1145,9 @@ static void parseGuardedPattern(Parser &P, GuardedPattern &result,
11031145
status.setIsParseError();
11041146
}
11051147
repeatedDecls.push_back(VD);
1148+
P.setLocalDiscriminator(VD);
1149+
if (VD->hasName())
1150+
P.addToScope(VD);
11061151
});
11071152

11081153
for (auto previous : boundDecls) {
@@ -1124,47 +1169,10 @@ static void parseGuardedPattern(Parser &P, GuardedPattern &result,
11241169
VD->setHasNonPatternBindingInit();
11251170
VD->setImplicit();
11261171
}
1127-
}
1128-
1129-
// Now that we have them, mark them as being initialized without a PBD.
1130-
for (auto VD : boundDecls)
1131-
VD->setHasNonPatternBindingInit();
1132-
1133-
// Parse the optional 'where' guard.
1134-
if (P.Tok.is(tok::kw_where)) {
1135-
SyntaxParsingContext WhereClauseCtxt(P.SyntaxContext,
1136-
SyntaxKind::WhereClause);
1137-
result.WhereLoc = P.consumeToken(tok::kw_where);
1138-
SourceLoc startOfGuard = P.Tok.getLoc();
11391172

1140-
auto diagKind = [=]() -> Diag<> {
1141-
switch (parsingContext) {
1142-
case GuardedPatternContext::Case:
1143-
return diag::expected_case_where_expr;
1144-
case GuardedPatternContext::Catch:
1145-
return diag::expected_catch_where_expr;
1146-
}
1147-
llvm_unreachable("bad context");
1148-
}();
1149-
ParserResult<Expr> guardResult = P.parseExprImpl(diagKind, isExprBasic);
1150-
status |= guardResult;
1151-
1152-
// Use the parsed guard expression if possible.
1153-
if (guardResult.isNonNull()) {
1154-
result.Guard = guardResult.get();
1155-
1156-
// Otherwise, fake up an ErrorExpr.
1157-
} else {
1158-
// If we didn't consume any tokens failing to parse the
1159-
// expression, don't put in the source range of the ErrorExpr.
1160-
SourceRange errorRange;
1161-
if (startOfGuard == P.Tok.getLoc()) {
1162-
errorRange = result.WhereLoc;
1163-
} else {
1164-
errorRange = SourceRange(startOfGuard, P.PreviousLoc);
1165-
}
1166-
result.Guard = new (P.Context) ErrorExpr(errorRange);
1167-
}
1173+
// Parse the optional 'where' guard, with this particular pattern's bound
1174+
// vars in scope.
1175+
parseWhereGuard(P, result, status, parsingContext, isExprBasic);
11681176
}
11691177
}
11701178

lib/SILGen/SILGenPattern.cpp

Lines changed: 10 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,9 @@ class PatternMatchEmission {
451451

452452
void bindRefutablePatterns(const ClauseRow &row, ArgArray args,
453453
const FailureHandler &failure);
454+
454455
void emitGuardBranch(SILLocation loc, Expr *guard,
455-
const FailureHandler &failure,
456-
Pattern *usingImplicitVariablesFromPattern,
457-
CaseStmt *usingImplicitVariablesFromStmt);
456+
const FailureHandler &failure);
458457

459458
void bindIrrefutablePatterns(const ClauseRow &row, ArgArray args,
460459
bool forIrrefutableRow, bool hasMultipleItems);
@@ -1076,19 +1075,17 @@ void PatternMatchEmission::emitWildcardDispatch(ClauseMatrix &clauses,
10761075
auto stmt = clauses[row].getClientData<Stmt>();
10771076
assert(isa<CaseStmt>(stmt) || isa<CatchStmt>(stmt));
10781077

1079-
bool hasMultipleItems = false;
1080-
if (auto *caseStmt = dyn_cast<CaseStmt>(stmt)) {
1081-
hasMultipleItems = clauses[row].hasFallthroughTo() ||
1082-
caseStmt->getCaseLabelItems().size() > 1;
1083-
}
1078+
auto *caseStmt = dyn_cast<CaseStmt>(stmt);
1079+
bool hasMultipleItems =
1080+
caseStmt && (clauses[row].hasFallthroughTo() ||
1081+
caseStmt->getCaseLabelItems().size() > 1);
10841082

10851083
// Bind the rest of the patterns.
10861084
bindIrrefutablePatterns(clauses[row], args, !hasGuard, hasMultipleItems);
10871085

10881086
// Emit the guard branch, if it exists.
10891087
if (guardExpr) {
1090-
this->emitGuardBranch(guardExpr, guardExpr, failure,
1091-
clauses[row].getCasePattern(), dyn_cast<CaseStmt>(stmt));
1088+
this->emitGuardBranch(guardExpr, guardExpr, failure);
10921089
}
10931090

10941091
// Enter the row.
@@ -1120,8 +1117,7 @@ bindRefutablePatterns(const ClauseRow &row, ArgArray args,
11201117
FullExpr scope(SGF.Cleanups, CleanupLocation(pattern));
11211118
bindVariable(pattern, exprPattern->getMatchVar(), args[i],
11221119
/*isForSuccess*/ false, /* hasMultipleItems */ false);
1123-
emitGuardBranch(pattern, exprPattern->getMatchExpr(), failure, nullptr,
1124-
nullptr);
1120+
emitGuardBranch(pattern, exprPattern->getMatchExpr(), failure);
11251121
break;
11261122
}
11271123
default:
@@ -1197,26 +1193,15 @@ void PatternMatchEmission::bindVariable(Pattern *pattern, VarDecl *var,
11971193
/// Evaluate a guard expression and, if it returns false, branch to
11981194
/// the given destination.
11991195
void PatternMatchEmission::emitGuardBranch(SILLocation loc, Expr *guard,
1200-
const FailureHandler &failure,
1201-
Pattern *usingImplicitVariablesFromPattern,
1202-
CaseStmt *usingImplicitVariablesFromStmt) {
1196+
const FailureHandler &failure) {
12031197
SILBasicBlock *falseBB = SGF.B.splitBlockForFallthrough();
12041198
SILBasicBlock *trueBB = SGF.B.splitBlockForFallthrough();
12051199

12061200
// Emit the match test.
12071201
SILValue testBool;
12081202
{
12091203
FullExpr scope(SGF.Cleanups, CleanupLocation(guard));
1210-
auto emitTest = [&]{
1211-
testBool = SGF.emitRValueAsSingleValue(guard).getUnmanagedValue();
1212-
};
1213-
1214-
if (usingImplicitVariablesFromPattern)
1215-
SGF.usingImplicitVariablesForPattern(usingImplicitVariablesFromPattern,
1216-
usingImplicitVariablesFromStmt,
1217-
emitTest);
1218-
else
1219-
emitTest();
1204+
testBool = SGF.emitRValueAsSingleValue(guard).getUnmanagedValue();
12201205
}
12211206

12221207
SGF.B.createCondBranch(loc, testBool, trueBB, falseBB);
@@ -2442,47 +2427,6 @@ class Lowering::PatternMatchContext {
24422427
PatternMatchEmission &Emission;
24432428
};
24442429

2445-
void SILGenFunction::usingImplicitVariablesForPattern(Pattern *pattern, CaseStmt *stmt,
2446-
const llvm::function_ref<void(void)> &f) {
2447-
// Early exit for CatchStmt
2448-
if (!stmt) {
2449-
f();
2450-
return;
2451-
}
2452-
2453-
ArrayRef<CaseLabelItem> labelItems = stmt->getCaseLabelItems();
2454-
auto expectedPattern = labelItems[0].getPattern();
2455-
2456-
if (labelItems.size() <= 1 || pattern == expectedPattern) {
2457-
f();
2458-
return;
2459-
}
2460-
2461-
// Remap vardecls that the case body is expecting to the pattern var locations
2462-
// for the given pattern, emit whatever, and switch back.
2463-
SmallVector<VarDecl *, 4> Vars;
2464-
expectedPattern->collectVariables(Vars);
2465-
2466-
auto variableSwapper = [&] {
2467-
pattern->forEachVariable([&](VarDecl *VD) {
2468-
if (!VD->hasName())
2469-
return;
2470-
for (auto expected : Vars) {
2471-
if (expected->hasName() && expected->getName() == VD->getName()) {
2472-
auto swap = VarLocs[expected];
2473-
VarLocs[expected] = VarLocs[VD];
2474-
VarLocs[VD] = swap;
2475-
return;
2476-
}
2477-
}
2478-
});
2479-
};
2480-
2481-
variableSwapper();
2482-
f();
2483-
variableSwapper();
2484-
}
2485-
24862430
static void emitDiagnoseOfUnexpectedEnumCaseValue(SILGenFunction &SGF,
24872431
SILLocation loc,
24882432
ManagedValue value,

test/refactoring/rename/Outputs/local/casebind_2.swift.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func test3(arg: Int?) {
2121
case let .some(x) where x == 0:
2222
print(x)
2323
case let .some(xRenamed) where xRenamed == 1,
24-
let .some(x) where xRenamed == 2: // FIXME: This 'x' in '.some(x)' isn't properly renamed in 'casebind_2' case.
24+
let .some(x) where x == 2: // FIXME: This 'x' in '.some(x)' isn't properly renamed in 'casebind_2' case.
2525
print(xRenamed)
2626
default:
2727
break

0 commit comments

Comments
 (0)