Skip to content

[sema] Add the ability for VarDecls to have parent VarDecls and wire up the VarDecl linked list #23378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4577,7 +4577,7 @@ class VarDecl : public AbstractStorageDecl {
};

protected:
llvm::PointerUnion<PatternBindingDecl*, Stmt*> ParentPattern;
PointerUnion3<PatternBindingDecl *, Stmt *, VarDecl *> Parent;

VarDecl(DeclKind Kind, bool IsStatic, Specifier Sp, bool IsCaptureList,
SourceLoc NameLoc, Identifier Name, DeclContext *DC)
Expand Down Expand Up @@ -4648,12 +4648,15 @@ class VarDecl : public AbstractStorageDecl {
/// Return the parent pattern binding that may provide an initializer for this
/// VarDecl. This returns null if there is none associated with the VarDecl.
PatternBindingDecl *getParentPatternBinding() const {
return ParentPattern.dyn_cast<PatternBindingDecl *>();
if (!Parent)
return nullptr;
return Parent.dyn_cast<PatternBindingDecl *>();
}
void setParentPatternBinding(PatternBindingDecl *PBD) {
ParentPattern = PBD;
assert(PBD);
Parent = PBD;
}

/// Return the Pattern involved in initializing this VarDecl. However, recall
/// that the Pattern may be involved in initializing more than just this one
/// vardecl. For example, if this is a VarDecl for "x", the pattern may be
Expand All @@ -4664,16 +4667,53 @@ class VarDecl : public AbstractStorageDecl {
/// returns null.
///
Pattern *getParentPattern() const;

/// Return the statement that owns the pattern associated with this VarDecl,
/// if one exists.
///
/// NOTE: After parsing and before type checking, all VarDecls from
/// CaseLabelItem's Patterns return their CaseStmt. After type checking, we
/// will have constructed the CaseLabelItem VarDecl linked list implying this
/// will return nullptr. After type checking, if one wishes to find a parent
/// pattern of a VarDecl of a CaseStmt, \see getRecursiveParentPatternStmt
/// instead.
Stmt *getParentPatternStmt() const {
return ParentPattern.dyn_cast<Stmt*>();
if (!Parent)
return nullptr;
return Parent.dyn_cast<Stmt *>();
}
void setParentPatternStmt(Stmt *S) {
ParentPattern = S;

void setParentPatternStmt(Stmt *s) {
assert(s);
Parent = s;
}

/// Look for the parent pattern stmt of this var decl, recursively
/// looking through var decl pointers and then through any
/// fallthroughts.
Stmt *getRecursiveParentPatternStmt() const;

/// Returns the var decl that this var decl is an implicit reference to if
/// such a var decl exists.
VarDecl *getParentVarDecl() const {
if (!Parent)
return nullptr;
return Parent.dyn_cast<VarDecl *>();
}

/// Set \p v to be the pattern produced VarDecl that is the parent of this
/// var decl.
void setParentVarDecl(VarDecl *v) {
assert(v);
Parent = v;
}

/// If this is a VarDecl that does not belong to a CaseLabelItem's pattern,
/// return this. Otherwise, this VarDecl must belong to a CaseStmt's
/// CaseLabelItem. In that case, return the first case label item of the first
/// case stmt in a sequence of case stmts that fallthrough into each other.
VarDecl *getCanonicalVarDecl() const;

/// True if the global stored property requires lazy initialization.
bool isLazilyInitializedGlobal() const;

Expand Down
22 changes: 12 additions & 10 deletions include/swift/Basic/LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace llvm {
template<typename T> class TinyPtrVector;
template<typename T> class Optional;
template <typename PT1, typename PT2> class PointerUnion;
template <typename PT1, typename PT2, typename PT3> class PointerUnion3;
class SmallBitVector;

// Other common classes.
Expand All @@ -62,22 +63,23 @@ namespace swift {
using llvm::cast_or_null;

// Containers.
using llvm::ArrayRef;
using llvm::MutableArrayRef;
using llvm::None;
using llvm::Optional;
using llvm::SmallPtrSetImpl;
using llvm::PointerUnion;
using llvm::PointerUnion3;
using llvm::SmallBitVector;
using llvm::SmallPtrSet;
using llvm::SmallPtrSetImpl;
using llvm::SmallSetVector;
using llvm::SmallString;
using llvm::StringRef;
using llvm::StringLiteral;
using llvm::Twine;
using llvm::SmallVectorImpl;
using llvm::SmallVector;
using llvm::ArrayRef;
using llvm::MutableArrayRef;
using llvm::SmallVectorImpl;
using llvm::StringLiteral;
using llvm::StringRef;
using llvm::TinyPtrVector;
using llvm::PointerUnion;
using llvm::SmallSetVector;
using llvm::SmallBitVector;
using llvm::Twine;

// Other common classes.
using llvm::APFloat;
Expand Down
86 changes: 78 additions & 8 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4874,14 +4874,74 @@ SourceRange VarDecl::getTypeSourceRangeForDiagnostics() const {
return SourceRange();
}

static bool isVarInPattern(const VarDecl *VD, Pattern *P) {
static bool isVarInPattern(const VarDecl *vd, Pattern *p) {
bool foundIt = false;
P->forEachVariable([&](VarDecl *FoundFD) {
foundIt |= FoundFD == VD;
});
p->forEachVariable([&](VarDecl *foundFD) { foundIt |= foundFD == vd; });
return foundIt;
}

static Optional<std::pair<CaseStmt *, Pattern *>>
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
for (auto &item : cs->getMutableCaseLabelItems()) {
if (isVarInPattern(inputVD, item.getPattern())) {
return item.getPattern();
}
}
return nullptr;
};

// First find our canonical var decl. This is the VarDecl corresponding to the
// first case label item of the first case block in the fallthrough chain that
// our case block is within. Grab the case stmt associated with that var decl
// and start traveling down the fallthrough chain looking for the case
// statement that the input VD belongs to by using getMatchingPattern().
auto *canonicalVD = inputVD->getCanonicalVarDecl();
auto *caseStmt =
dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt());
if (!caseStmt)
return None;

if (auto *p = getMatchingPattern(caseStmt))
return std::make_pair(caseStmt, p);

while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) {
if (auto *p = getMatchingPattern(caseStmt))
return std::make_pair(caseStmt, p);
}

return None;
}

VarDecl *VarDecl::getCanonicalVarDecl() const {
// Any var decl without a parent var decl is canonical. This means that before
// type checking, all var decls are canonical.
auto *cur = const_cast<VarDecl *>(this);
auto *vd = cur->getParentVarDecl();
if (!vd)
return cur;

while (vd) {
cur = vd;
vd = vd->getParentVarDecl();
}

return cur;
}

Stmt *VarDecl::getRecursiveParentPatternStmt() const {
// If our parent is already a pattern stmt, just return that.
if (auto *stmt = getParentPatternStmt())
return stmt;

// Otherwise, see if we have a parent var decl. If we do not, then return
// nullptr. Otherwise, return the case stmt that we found.
auto result = findParentPatternCaseStmtAndPattern(this);
if (!result.hasValue())
return nullptr;
return result->first;
}

/// Return the Pattern involved in initializing this VarDecl. Recall that the
/// Pattern may be involved in initializing more than just this one vardecl
/// though. For example, if this is a VarDecl for "x", the pattern may be
Expand All @@ -4903,25 +4963,35 @@ Pattern *VarDecl::getParentPattern() const {

if (auto *CS = dyn_cast<CatchStmt>(stmt))
return CS->getErrorPattern();

if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
// In a case statement, search for the pattern that contains it. This is
// a bit silly, because you can't have something like "case x, y:" anyway.
for (auto items : cs->getCaseLabelItems()) {
if (isVarInPattern(this, items.getPattern()))
return items.getPattern();
}
} else if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
}

if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
for (auto &elt : LCS->getCond())
if (auto pat = elt.getPatternOrNull())
if (isVarInPattern(this, pat))
return pat;
}

//stmt->dump();
assert(0 && "Unknown parent pattern statement?");
}


// Otherwise, check if we have to walk our case stmt's var decl list to find
// the pattern.
if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) {
return caseStmtPatternPair->second;
}

// Otherwise, this is a case we do not know or understand. Return nullptr to
// signal we do not have any information.
return nullptr;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/MiscDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2623,7 +2623,7 @@ VarDeclUsageChecker::~VarDeclUsageChecker() {
}
else {
bool suggestLet = true;
if (auto *stmt = var->getParentPatternStmt()) {
if (auto *stmt = var->getRecursiveParentPatternStmt()) {
// Don't try to suggest 'var' -> 'let' conversion
// in case of 'for' loop because it's an implicitly
// immutable context.
Expand Down
Loading