Skip to content

Commit 4c389dd

Browse files
committed
[sema] Wire up VarDecl parent pointers for case stmt related Var Decls
This is in preparation for fixing issues around SILGenPattern fallthrough emission and bad rename/edit all in scope of case stmt var decls. Specifically, I am going to ensure that we can get from any VarDecl in the following to any other VarDecl: switch x { case .a(let v1, let v2), .b(let v1, let v2): ... fallthrough case .c(let v1, let v2), .d(let v1, let v2): ... } This will be done by: 1. Pointing the var decls in .d at the corresponding var decls in .c. 2. Pointing the var decls in .c at the corresponding var decls in .b. 3. Pointing the var decls in .b at the corresponding var decls in .a. 4. Pointing the var decls in .a at the case stmt. Recognizing that we are asking for the next VarDecl, but have a case stmt, we check if we have a fallthrough case stmt. If so, follow down the fallthrough case stmts until you find a fallthrough case stmt that doesn't fallthrough itself and then return the corresponding var decl in the last case label item in that var decl (in the above .d). In a subsequent commit I am going to add case body var decls. The only change as a result of that is that I will insert them into the VarDecl double linked list after the last case var decl of each case stmt. (cherry picked from commit b1a7b48)
1 parent ac2e8fc commit 4c389dd

File tree

4 files changed

+283
-80
lines changed

4 files changed

+283
-80
lines changed

include/swift/AST/Decl.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4669,19 +4669,32 @@ class VarDecl : public AbstractStorageDecl {
46694669
/// returns null.
46704670
///
46714671
Pattern *getParentPattern() const;
4672-
4672+
46734673
/// Return the statement that owns the pattern associated with this VarDecl,
46744674
/// if one exists.
4675+
///
4676+
/// NOTE: After parsing and before type checking, all VarDecls from
4677+
/// CaseLabelItem's Patterns return their CaseStmt. After type checking, we
4678+
/// will have constructed the CaseLabelItem VarDecl linked list implying this
4679+
/// will return nullptr. After type checking, if one wishes to find a parent
4680+
/// pattern of a VarDecl of a CaseStmt, \see getRecursiveParentPatternStmt
4681+
/// instead.
46754682
Stmt *getParentPatternStmt() const {
46764683
if (!Parent)
46774684
return nullptr;
46784685
return Parent.dyn_cast<Stmt *>();
46794686
}
4687+
46804688
void setParentPatternStmt(Stmt *s) {
46814689
assert(s);
46824690
Parent = s;
46834691
}
46844692

4693+
/// Look for the parent pattern stmt of this var decl, recursively
4694+
/// looking through var decl pointers and then through any
4695+
/// fallthroughts.
4696+
Stmt *getRecursiveParentPatternStmt() const;
4697+
46854698
/// Returns the var decl that this var decl is an implicit reference to if
46864699
/// such a var decl exists.
46874700
VarDecl *getParentVarDecl() const {
@@ -4697,6 +4710,12 @@ class VarDecl : public AbstractStorageDecl {
46974710
Parent = v;
46984711
}
46994712

4713+
/// If this is a VarDecl that does not belong to a CaseLabelItem's pattern,
4714+
/// return this. Otherwise, this VarDecl must belong to a CaseStmt's
4715+
/// CaseLabelItem. In that case, return the first case label item of the first
4716+
/// case stmt in a sequence of case stmts that fallthrough into each other.
4717+
VarDecl *getCanonicalVarDecl() const;
4718+
47004719
/// True if the global stored property requires lazy initialization.
47014720
bool isLazilyInitializedGlobal() const;
47024721

lib/AST/Decl.cpp

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4876,14 +4876,74 @@ SourceRange VarDecl::getTypeSourceRangeForDiagnostics() const {
48764876
return SourceRange();
48774877
}
48784878

4879-
static bool isVarInPattern(const VarDecl *VD, Pattern *P) {
4879+
static bool isVarInPattern(const VarDecl *vd, Pattern *p) {
48804880
bool foundIt = false;
4881-
P->forEachVariable([&](VarDecl *FoundFD) {
4882-
foundIt |= FoundFD == VD;
4883-
});
4881+
p->forEachVariable([&](VarDecl *foundFD) { foundIt |= foundFD == vd; });
48844882
return foundIt;
48854883
}
48864884

4885+
static Optional<std::pair<CaseStmt *, Pattern *>>
4886+
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
4887+
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
4888+
for (auto &item : cs->getMutableCaseLabelItems()) {
4889+
if (isVarInPattern(inputVD, item.getPattern())) {
4890+
return item.getPattern();
4891+
}
4892+
}
4893+
return nullptr;
4894+
};
4895+
4896+
// First find our canonical var decl. This is the VarDecl corresponding to the
4897+
// first case label item of the first case block in the fallthrough chain that
4898+
// our case block is within. Grab the case stmt associated with that var decl
4899+
// and start traveling down the fallthrough chain looking for the case
4900+
// statement that the input VD belongs to by using getMatchingPattern().
4901+
auto *canonicalVD = inputVD->getCanonicalVarDecl();
4902+
auto *caseStmt =
4903+
dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt());
4904+
if (!caseStmt)
4905+
return None;
4906+
4907+
if (auto *p = getMatchingPattern(caseStmt))
4908+
return std::make_pair(caseStmt, p);
4909+
4910+
while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) {
4911+
if (auto *p = getMatchingPattern(caseStmt))
4912+
return std::make_pair(caseStmt, p);
4913+
}
4914+
4915+
return None;
4916+
}
4917+
4918+
VarDecl *VarDecl::getCanonicalVarDecl() const {
4919+
// Any var decl without a parent var decl is canonical. This means that before
4920+
// type checking, all var decls are canonical.
4921+
auto *cur = const_cast<VarDecl *>(this);
4922+
auto *vd = cur->getParentVarDecl();
4923+
if (!vd)
4924+
return cur;
4925+
4926+
while (vd) {
4927+
cur = vd;
4928+
vd = vd->getParentVarDecl();
4929+
}
4930+
4931+
return cur;
4932+
}
4933+
4934+
Stmt *VarDecl::getRecursiveParentPatternStmt() const {
4935+
// If our parent is already a pattern stmt, just return that.
4936+
if (auto *stmt = getParentPatternStmt())
4937+
return stmt;
4938+
4939+
// Otherwise, see if we have a parent var decl. If we do not, then return
4940+
// nullptr. Otherwise, return the case stmt that we found.
4941+
auto result = findParentPatternCaseStmtAndPattern(this);
4942+
if (!result.hasValue())
4943+
return nullptr;
4944+
return result->first;
4945+
}
4946+
48874947
/// Return the Pattern involved in initializing this VarDecl. Recall that the
48884948
/// Pattern may be involved in initializing more than just this one vardecl
48894949
/// though. For example, if this is a VarDecl for "x", the pattern may be
@@ -4905,25 +4965,35 @@ Pattern *VarDecl::getParentPattern() const {
49054965

49064966
if (auto *CS = dyn_cast<CatchStmt>(stmt))
49074967
return CS->getErrorPattern();
4908-
4968+
49094969
if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
49104970
// In a case statement, search for the pattern that contains it. This is
49114971
// a bit silly, because you can't have something like "case x, y:" anyway.
49124972
for (auto items : cs->getCaseLabelItems()) {
49134973
if (isVarInPattern(this, items.getPattern()))
49144974
return items.getPattern();
49154975
}
4916-
} else if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
4976+
}
4977+
4978+
if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
49174979
for (auto &elt : LCS->getCond())
49184980
if (auto pat = elt.getPatternOrNull())
49194981
if (isVarInPattern(this, pat))
49204982
return pat;
49214983
}
4922-
4984+
49234985
//stmt->dump();
49244986
assert(0 && "Unknown parent pattern statement?");
49254987
}
4926-
4988+
4989+
// Otherwise, check if we have to walk our case stmt's var decl list to find
4990+
// the pattern.
4991+
if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) {
4992+
return caseStmtPatternPair->second;
4993+
}
4994+
4995+
// Otherwise, this is a case we do not know or understand. Return nullptr to
4996+
// signal we do not have any information.
49274997
return nullptr;
49284998
}
49294999

lib/Sema/MiscDiagnostics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2623,7 +2623,7 @@ VarDeclUsageChecker::~VarDeclUsageChecker() {
26232623
}
26242624
else {
26252625
bool suggestLet = true;
2626-
if (auto *stmt = var->getParentPatternStmt()) {
2626+
if (auto *stmt = var->getRecursiveParentPatternStmt()) {
26272627
// Don't try to suggest 'var' -> 'let' conversion
26282628
// in case of 'for' loop because it's an implicitly
26292629
// immutable context.

0 commit comments

Comments
 (0)