Skip to content

Commit e2dbf67

Browse files
committed
SILGen: Push 'usingImplicitVariablesForPattern' hack into 'where' clause expr evaluation.
We swapped the pattern variables at the wrong level, leaving them bound incorrectly on the cleanup path through a failed 'where' clause check. Fixes rdar://problem/31539726.
1 parent 70a110a commit e2dbf67

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

lib/SILGen/SILGenPattern.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,9 @@ class PatternMatchEmission {
422422
void bindExprPattern(ExprPattern *pattern, ConsumableManagedValue v,
423423
const FailureHandler &failure);
424424
void emitGuardBranch(SILLocation loc, Expr *guard,
425-
const FailureHandler &failure);
425+
const FailureHandler &failure,
426+
Pattern *usingImplicitVariablesFromPattern,
427+
CaseStmt *usingImplicitVariablesFromStmt);
426428

427429
void bindIrrefutablePatterns(const ClauseRow &row, ArgArray args,
428430
bool forIrrefutableRow, bool hasMultipleItems);
@@ -1056,9 +1058,8 @@ void PatternMatchEmission::emitWildcardDispatch(ClauseMatrix &clauses,
10561058

10571059
// Emit the guard branch, if it exists.
10581060
if (guardExpr) {
1059-
SGF.usingImplicitVariablesForPattern(clauses[row].getCasePattern(), dyn_cast<CaseStmt>(stmt), [&]{
1060-
this->emitGuardBranch(guardExpr, guardExpr, failure);
1061-
});
1061+
this->emitGuardBranch(guardExpr, guardExpr, failure,
1062+
clauses[row].getCasePattern(), dyn_cast<CaseStmt>(stmt));
10621063
}
10631064

10641065
// Enter the row.
@@ -1120,7 +1121,8 @@ void PatternMatchEmission::bindExprPattern(ExprPattern *pattern,
11201121
bindVariable(pattern, pattern->getMatchVar(), value,
11211122
pattern->getType()->getCanonicalType(),
11221123
/*isForSuccess*/ false, /* hasMultipleItems */ false);
1123-
emitGuardBranch(pattern, pattern->getMatchExpr(), failure);
1124+
emitGuardBranch(pattern, pattern->getMatchExpr(), failure,
1125+
nullptr, nullptr);
11241126
}
11251127

11261128
/// Bind all the irrefutable patterns in the given row, which is nothing
@@ -1229,15 +1231,26 @@ void PatternMatchEmission::bindVariable(SILLocation loc, VarDecl *var,
12291231
/// Evaluate a guard expression and, if it returns false, branch to
12301232
/// the given destination.
12311233
void PatternMatchEmission::emitGuardBranch(SILLocation loc, Expr *guard,
1232-
const FailureHandler &failure) {
1234+
const FailureHandler &failure,
1235+
Pattern *usingImplicitVariablesFromPattern,
1236+
CaseStmt *usingImplicitVariablesFromStmt) {
12331237
SILBasicBlock *falseBB = SGF.B.splitBlockForFallthrough();
12341238
SILBasicBlock *trueBB = SGF.B.splitBlockForFallthrough();
12351239

12361240
// Emit the match test.
12371241
SILValue testBool;
12381242
{
12391243
FullExpr scope(SGF.Cleanups, CleanupLocation(guard));
1240-
testBool = SGF.emitRValueAsSingleValue(guard).getUnmanagedValue();
1244+
auto emitTest = [&]{
1245+
testBool = SGF.emitRValueAsSingleValue(guard).getUnmanagedValue();
1246+
};
1247+
1248+
if (usingImplicitVariablesFromPattern)
1249+
SGF.usingImplicitVariablesForPattern(usingImplicitVariablesFromPattern,
1250+
usingImplicitVariablesFromStmt,
1251+
emitTest);
1252+
else
1253+
emitTest();
12411254
}
12421255

12431256
SGF.B.createCondBranch(loc, testBool, trueBB, falseBB);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
6+
func boo(_: LifetimeTracked) -> Bool { return true }
7+
8+
var tests = TestSuite("switches with where clauses")
9+
10+
enum Foo {
11+
case a(LifetimeTracked)
12+
case b(LifetimeTracked)
13+
}
14+
15+
func foo(_ x: Foo, _ y: Foo, _ condition: (LifetimeTracked) -> Bool) -> Bool {
16+
switch (x, y) {
17+
case (.a(let xml), _),
18+
(_, .a(let xml)) where condition(xml):
19+
return true
20+
default:
21+
return false
22+
}
23+
}
24+
25+
tests.test("all paths through a switch with guard") {
26+
_ = foo(.a(LifetimeTracked(0)), .a(LifetimeTracked(1)), { _ in true })
27+
_ = foo(.a(LifetimeTracked(2)), .b(LifetimeTracked(3)), { _ in true })
28+
_ = foo(.b(LifetimeTracked(4)), .a(LifetimeTracked(5)), { _ in true })
29+
_ = foo(.b(LifetimeTracked(6)), .b(LifetimeTracked(7)), { _ in true })
30+
31+
_ = foo(.a(LifetimeTracked(10)), .a(LifetimeTracked(11)), { _ in false })
32+
_ = foo(.a(LifetimeTracked(12)), .b(LifetimeTracked(13)), { _ in false })
33+
_ = foo(.b(LifetimeTracked(14)), .a(LifetimeTracked(15)), { _ in false })
34+
_ = foo(.b(LifetimeTracked(16)), .b(LifetimeTracked(17)), { _ in false })
35+
}
36+
37+
runAllTests()

0 commit comments

Comments
 (0)