Skip to content

Commit b1a7b48

Browse files
committed
[sema] Wire up VarDecl parent pointers for case stmt related Var Decls
This is in preparation for fixing issues around SILGenPattern fallthrough emission and bad rename/edit all in scope of case stmt var decls. Specifically, I am going to ensure that we can get from any VarDecl in the following to any other VarDecl: switch x { case .a(let v1, let v2), .b(let v1, let v2): ... fallthrough case .c(let v1, let v2), .d(let v1, let v2): ... } This will be done by: 1. Pointing the var decls in .d at the corresponding var decls in .c. 2. Pointing the var decls in .c at the corresponding var decls in .b. 3. Pointing the var decls in .b at the corresponding var decls in .a. 4. Pointing the var decls in .a at the case stmt. Recognizing that we are asking for the next VarDecl, but have a case stmt, we check if we have a fallthrough case stmt. If so, follow down the fallthrough case stmts until you find a fallthrough case stmt that doesn't fallthrough itself and then return the corresponding var decl in the last case label item in that var decl (in the above .d). In a subsequent commit I am going to add case body var decls. The only change as a result of that is that I will insert them into the VarDecl double linked list after the last case var decl of each case stmt.
1 parent 500f34c commit b1a7b48

File tree

4 files changed

+283
-80
lines changed

4 files changed

+283
-80
lines changed

include/swift/AST/Decl.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4667,19 +4667,32 @@ class VarDecl : public AbstractStorageDecl {
46674667
/// returns null.
46684668
///
46694669
Pattern *getParentPattern() const;
4670-
4670+
46714671
/// Return the statement that owns the pattern associated with this VarDecl,
46724672
/// if one exists.
4673+
///
4674+
/// NOTE: After parsing and before type checking, all VarDecls from
4675+
/// CaseLabelItem's Patterns return their CaseStmt. After type checking, we
4676+
/// will have constructed the CaseLabelItem VarDecl linked list implying this
4677+
/// will return nullptr. After type checking, if one wishes to find a parent
4678+
/// pattern of a VarDecl of a CaseStmt, \see getRecursiveParentPatternStmt
4679+
/// instead.
46734680
Stmt *getParentPatternStmt() const {
46744681
if (!Parent)
46754682
return nullptr;
46764683
return Parent.dyn_cast<Stmt *>();
46774684
}
4685+
46784686
void setParentPatternStmt(Stmt *s) {
46794687
assert(s);
46804688
Parent = s;
46814689
}
46824690

4691+
/// Look for the parent pattern stmt of this var decl, recursively
4692+
/// looking through var decl pointers and then through any
4693+
/// fallthroughts.
4694+
Stmt *getRecursiveParentPatternStmt() const;
4695+
46834696
/// Returns the var decl that this var decl is an implicit reference to if
46844697
/// such a var decl exists.
46854698
VarDecl *getParentVarDecl() const {
@@ -4695,6 +4708,12 @@ class VarDecl : public AbstractStorageDecl {
46954708
Parent = v;
46964709
}
46974710

4711+
/// If this is a VarDecl that does not belong to a CaseLabelItem's pattern,
4712+
/// return this. Otherwise, this VarDecl must belong to a CaseStmt's
4713+
/// CaseLabelItem. In that case, return the first case label item of the first
4714+
/// case stmt in a sequence of case stmts that fallthrough into each other.
4715+
VarDecl *getCanonicalVarDecl() const;
4716+
46984717
/// True if the global stored property requires lazy initialization.
46994718
bool isLazilyInitializedGlobal() const;
47004719

lib/AST/Decl.cpp

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4874,14 +4874,74 @@ SourceRange VarDecl::getTypeSourceRangeForDiagnostics() const {
48744874
return SourceRange();
48754875
}
48764876

4877-
static bool isVarInPattern(const VarDecl *VD, Pattern *P) {
4877+
static bool isVarInPattern(const VarDecl *vd, Pattern *p) {
48784878
bool foundIt = false;
4879-
P->forEachVariable([&](VarDecl *FoundFD) {
4880-
foundIt |= FoundFD == VD;
4881-
});
4879+
p->forEachVariable([&](VarDecl *foundFD) { foundIt |= foundFD == vd; });
48824880
return foundIt;
48834881
}
48844882

4883+
static Optional<std::pair<CaseStmt *, Pattern *>>
4884+
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
4885+
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
4886+
for (auto &item : cs->getMutableCaseLabelItems()) {
4887+
if (isVarInPattern(inputVD, item.getPattern())) {
4888+
return item.getPattern();
4889+
}
4890+
}
4891+
return nullptr;
4892+
};
4893+
4894+
// First find our canonical var decl. This is the VarDecl corresponding to the
4895+
// first case label item of the first case block in the fallthrough chain that
4896+
// our case block is within. Grab the case stmt associated with that var decl
4897+
// and start traveling down the fallthrough chain looking for the case
4898+
// statement that the input VD belongs to by using getMatchingPattern().
4899+
auto *canonicalVD = inputVD->getCanonicalVarDecl();
4900+
auto *caseStmt =
4901+
dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt());
4902+
if (!caseStmt)
4903+
return None;
4904+
4905+
if (auto *p = getMatchingPattern(caseStmt))
4906+
return std::make_pair(caseStmt, p);
4907+
4908+
while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) {
4909+
if (auto *p = getMatchingPattern(caseStmt))
4910+
return std::make_pair(caseStmt, p);
4911+
}
4912+
4913+
return None;
4914+
}
4915+
4916+
VarDecl *VarDecl::getCanonicalVarDecl() const {
4917+
// Any var decl without a parent var decl is canonical. This means that before
4918+
// type checking, all var decls are canonical.
4919+
auto *cur = const_cast<VarDecl *>(this);
4920+
auto *vd = cur->getParentVarDecl();
4921+
if (!vd)
4922+
return cur;
4923+
4924+
while (vd) {
4925+
cur = vd;
4926+
vd = vd->getParentVarDecl();
4927+
}
4928+
4929+
return cur;
4930+
}
4931+
4932+
Stmt *VarDecl::getRecursiveParentPatternStmt() const {
4933+
// If our parent is already a pattern stmt, just return that.
4934+
if (auto *stmt = getParentPatternStmt())
4935+
return stmt;
4936+
4937+
// Otherwise, see if we have a parent var decl. If we do not, then return
4938+
// nullptr. Otherwise, return the case stmt that we found.
4939+
auto result = findParentPatternCaseStmtAndPattern(this);
4940+
if (!result.hasValue())
4941+
return nullptr;
4942+
return result->first;
4943+
}
4944+
48854945
/// Return the Pattern involved in initializing this VarDecl. Recall that the
48864946
/// Pattern may be involved in initializing more than just this one vardecl
48874947
/// though. For example, if this is a VarDecl for "x", the pattern may be
@@ -4903,25 +4963,35 @@ Pattern *VarDecl::getParentPattern() const {
49034963

49044964
if (auto *CS = dyn_cast<CatchStmt>(stmt))
49054965
return CS->getErrorPattern();
4906-
4966+
49074967
if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
49084968
// In a case statement, search for the pattern that contains it. This is
49094969
// a bit silly, because you can't have something like "case x, y:" anyway.
49104970
for (auto items : cs->getCaseLabelItems()) {
49114971
if (isVarInPattern(this, items.getPattern()))
49124972
return items.getPattern();
49134973
}
4914-
} else if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
4974+
}
4975+
4976+
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
49154977
for (auto &elt : LCS->getCond())
49164978
if (auto pat = elt.getPatternOrNull())
49174979
if (isVarInPattern(this, pat))
49184980
return pat;
49194981
}
4920-
4982+
49214983
//stmt->dump();
49224984
assert(0 && "Unknown parent pattern statement?");
49234985
}
4924-
4986+
4987+
// Otherwise, check if we have to walk our case stmt's var decl list to find
4988+
// the pattern.
4989+
if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) {
4990+
return caseStmtPatternPair->second;
4991+
}
4992+
4993+
// Otherwise, this is a case we do not know or understand. Return nullptr to
4994+
// signal we do not have any information.
49254995
return nullptr;
49264996
}
49274997

lib/Sema/MiscDiagnostics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2623,7 +2623,7 @@ VarDeclUsageChecker::~VarDeclUsageChecker() {
26232623
}
26242624
else {
26252625
bool suggestLet = true;
2626-
if (auto *stmt = var->getParentPatternStmt()) {
2626+
if (auto *stmt = var->getRecursiveParentPatternStmt()) {
26272627
// Don't try to suggest 'var' -> 'let' conversion
26282628
// in case of 'for' loop because it's an implicitly
26292629
// immutable context.

0 commit comments

Comments
 (0)