Skip to content

Commit 8166ad4

Browse files
authored
Merge pull request #41343 from xedin/se-326-conditions-solving
[CSClosure] SE-0326: Type-checker statement conditions individually
2 parents 352b3a2 + 0cc8bc7 commit 8166ad4

File tree

7 files changed

+44
-45
lines changed

7 files changed

+44
-45
lines changed

include/swift/AST/ASTNode.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,9 @@ namespace swift {
4343
enum class PatternKind : uint8_t;
4444
enum class StmtKind;
4545

46-
using StmtCondition = llvm::MutableArrayRef<StmtConditionElement>;
47-
4846
struct ASTNode
4947
: public llvm::PointerUnion<Expr *, Stmt *, Decl *, Pattern *, TypeRepr *,
50-
StmtCondition *, CaseLabelItem *> {
48+
StmtConditionElement *, CaseLabelItem *> {
5149
// Inherit the constructors from PointerUnion.
5250
using PointerUnion::PointerUnion;
5351

include/swift/AST/Stmt.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ class alignas(8) PoundAvailableInfo final :
395395
/// the "x" binding, one for the "y" binding, one for the where clause, one for
396396
/// "z"'s binding. A simple "if" statement is represented as a single binding.
397397
///
398-
class StmtConditionElement {
398+
class alignas(1 << PatternAlignInBits) StmtConditionElement {
399399
/// If this is a pattern binding, it may be the first one in a declaration, in
400400
/// which case this is the location of the var/let/case keyword. If this is
401401
/// the second pattern (e.g. for 'y' in "var x = ..., y = ...") then this
@@ -818,7 +818,7 @@ class ForEachStmt : public LabeledStmt {
818818
};
819819

820820
/// A pattern and an optional guard expression used in a 'case' statement.
821-
class CaseLabelItem {
821+
class alignas(1 << PatternAlignInBits) CaseLabelItem {
822822
enum class Kind {
823823
/// A normal pattern
824824
Normal = 0,

include/swift/AST/TypeAlignments.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ namespace swift {
6161
class TypeRepr;
6262
class ValueDecl;
6363
class CaseLabelItem;
64+
class StmtConditionElement;
6465

6566
/// We frequently use three tag bits on all of these types.
6667
constexpr size_t AttrAlignInBits = 3;
@@ -155,6 +156,9 @@ LLVM_DECLARE_TYPE_ALIGNMENT(swift::TypeRepr, swift::TypeReprAlignInBits)
155156

156157
LLVM_DECLARE_TYPE_ALIGNMENT(swift::CaseLabelItem, swift::PatternAlignInBits)
157158

159+
LLVM_DECLARE_TYPE_ALIGNMENT(swift::StmtConditionElement,
160+
swift::PatternAlignInBits)
161+
158162
static_assert(alignof(void*) >= 2, "pointer alignment is too small");
159163

160164
#endif

include/swift/Sema/ConstraintLocator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ class LocatorPathElt::ClosureBodyElement final
10361036
if (auto *repr = node.dyn_cast<TypeRepr *>())
10371037
return repr;
10381038

1039-
if (auto *cond = node.dyn_cast<StmtCondition *>())
1039+
if (auto *cond = node.dyn_cast<StmtConditionElement *>())
10401040
return cond;
10411041

10421042
if (auto *caseItem = node.dyn_cast<CaseLabelItem *>())

lib/AST/ASTNode.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,8 @@ SourceRange ASTNode::getSourceRange() const {
3535
return P->getSourceRange();
3636
if (const auto *T = this->dyn_cast<TypeRepr *>())
3737
return T->getSourceRange();
38-
if (const auto *C = this->dyn_cast<StmtCondition *>()) {
39-
if (C->empty())
40-
return SourceRange();
41-
42-
auto first = C->front();
43-
auto last = C->back();
44-
45-
return {first.getStartLoc(), last.getEndLoc()};
46-
}
38+
if (const auto *C = this->dyn_cast<StmtConditionElement *>())
39+
return C->getSourceRange();
4740
if (const auto *I = this->dyn_cast<CaseLabelItem *>()) {
4841
return I->getSourceRange();
4942
}
@@ -85,7 +78,7 @@ bool ASTNode::isImplicit() const {
8578
return P->isImplicit();
8679
if (const auto *T = this->dyn_cast<TypeRepr*>())
8780
return false;
88-
if (const auto *C = this->dyn_cast<StmtCondition *>())
81+
if (const auto *C = this->dyn_cast<StmtConditionElement *>())
8982
return false;
9083
if (const auto *I = this->dyn_cast<CaseLabelItem *>())
9184
return false;
@@ -103,10 +96,9 @@ void ASTNode::walk(ASTWalker &Walker) {
10396
P->walk(Walker);
10497
else if (auto *T = this->dyn_cast<TypeRepr*>())
10598
T->walk(Walker);
106-
else if (auto *C = this->dyn_cast<StmtCondition *>()) {
107-
for (auto &elt : *C)
108-
elt.walk(Walker);
109-
} else if (auto *I = this->dyn_cast<CaseLabelItem *>()) {
99+
else if (auto *C = this->dyn_cast<StmtConditionElement *>())
100+
C->walk(Walker);
101+
else if (auto *I = this->dyn_cast<CaseLabelItem *>()) {
110102
if (auto *P = I->getPattern())
111103
P->walk(Walker);
112104

@@ -127,9 +119,9 @@ void ASTNode::dump(raw_ostream &OS, unsigned Indent) const {
127119
P->dump(OS, Indent);
128120
else if (auto T = dyn_cast<TypeRepr*>())
129121
T->print(OS);
130-
else if (auto C = dyn_cast<StmtCondition *>()) {
131-
OS.indent(Indent) << "(statement conditions)";
132-
} else if (auto *I = dyn_cast<CaseLabelItem *>()) {
122+
else if (auto *C = dyn_cast<StmtConditionElement *>())
123+
OS.indent(Indent) << "(statement condition)";
124+
else if (auto *I = dyn_cast<CaseLabelItem *>()) {
133125
OS.indent(Indent) << "(case label item)";
134126
} else
135127
llvm_unreachable("unsupported AST node");

lib/Sema/CSClosure.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -531,18 +531,23 @@ class ClosureConstraintGenerator
531531
"Unsupported statement: Fallthrough");
532532
}
533533

534+
void visitStmtCondition(LabeledConditionalStmt *S,
535+
SmallVectorImpl<ElementInfo> &elements,
536+
ConstraintLocator *locator) {
537+
auto *condLocator =
538+
cs.getConstraintLocator(locator, ConstraintLocator::Condition);
539+
for (auto &condition : S->getCond())
540+
elements.push_back(makeElement(&condition, condLocator));
541+
}
542+
534543
void visitIfStmt(IfStmt *ifStmt) {
535544
assert(isSupportedMultiStatementClosure() &&
536545
"Unsupported statement: If");
537546

538547
SmallVector<ElementInfo, 4> elements;
539548

540549
// Condition
541-
{
542-
auto *condLoc =
543-
cs.getConstraintLocator(locator, ConstraintLocator::Condition);
544-
elements.push_back(makeElement(ifStmt->getCondPointer(), condLoc));
545-
}
550+
visitStmtCondition(ifStmt, elements, locator);
546551

547552
// Then Branch
548553
{
@@ -565,24 +570,24 @@ class ClosureConstraintGenerator
565570
assert(isSupportedMultiStatementClosure() &&
566571
"Unsupported statement: Guard");
567572

568-
createConjunction(cs,
569-
{makeElement(guardStmt->getCondPointer(),
570-
cs.getConstraintLocator(
571-
locator, ConstraintLocator::Condition)),
572-
makeElement(guardStmt->getBody(), locator)},
573-
locator);
573+
SmallVector<ElementInfo, 4> elements;
574+
575+
visitStmtCondition(guardStmt, elements, locator);
576+
elements.push_back(makeElement(guardStmt->getBody(), locator));
577+
578+
createConjunction(cs, elements, locator);
574579
}
575580

576581
void visitWhileStmt(WhileStmt *whileStmt) {
577582
assert(isSupportedMultiStatementClosure() &&
578583
"Unsupported statement: While");
579584

580-
createConjunction(cs,
581-
{makeElement(whileStmt->getCondPointer(),
582-
cs.getConstraintLocator(
583-
locator, ConstraintLocator::Condition)),
584-
makeElement(whileStmt->getBody(), locator)},
585-
locator);
585+
SmallVector<ElementInfo, 4> elements;
586+
587+
visitStmtCondition(whileStmt, elements, locator);
588+
elements.push_back(makeElement(whileStmt->getBody(), locator));
589+
590+
createConjunction(cs, elements, locator);
586591
}
587592

588593
void visitDoStmt(DoStmt *doStmt) {
@@ -970,8 +975,8 @@ ConstraintSystem::simplifyClosureBodyElementConstraint(
970975
return SolutionKind::Solved;
971976
} else if (auto *stmt = element.dyn_cast<Stmt *>()) {
972977
generator.visit(stmt);
973-
} else if (auto *cond = element.dyn_cast<StmtCondition *>()) {
974-
if (generateConstraints(*cond, closure))
978+
} else if (auto *cond = element.dyn_cast<StmtConditionElement *>()) {
979+
if (generateConstraints({*cond}, closure))
975980
return SolutionKind::Error;
976981
} else if (auto *pattern = element.dyn_cast<Pattern *>()) {
977982
generator.visitPattern(pattern, context);
@@ -1571,7 +1576,7 @@ void ConjunctionElement::findReferencedVariables(
15711576

15721577
TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars);
15731578

1574-
if (element.is<Decl *>() || element.is<StmtCondition *>() ||
1579+
if (element.is<Decl *>() || element.is<StmtConditionElement *>() ||
15751580
element.is<Expr *>() || element.isStmt(StmtKind::Return))
15761581
element.walk(refFinder);
15771582
}

lib/Sema/ConstraintSystem.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6050,8 +6050,8 @@ SourceLoc constraints::getLoc(ASTNode anchor) {
60506050
return S->getStartLoc();
60516051
} else if (auto *P = anchor.dyn_cast<Pattern *>()) {
60526052
return P->getLoc();
6053-
} else if (auto *C = anchor.dyn_cast<StmtCondition *>()) {
6054-
return C->front().getStartLoc();
6053+
} else if (auto *C = anchor.dyn_cast<StmtConditionElement *>()) {
6054+
return C->getStartLoc();
60556055
} else {
60566056
auto *I = anchor.get<CaseLabelItem *>();
60576057
return I->getStartLoc();

0 commit comments

Comments
 (0)