Skip to content

Commit e0785bc

Browse files
committed
[AST] 'CaseStmt' to hold the parent 'SwitchStmt' or 'DoCatchStmt'
1 parent c17966e commit e0785bc

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

include/swift/AST/Stmt.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class Pattern;
3939
class PatternBindingDecl;
4040
class VarDecl;
4141
class CaseStmt;
42+
class DoCatchStmt;
43+
class SwitchStmt;
4244

4345
enum class StmtKind {
4446
#define STMT(ID, PARENT) ID,
@@ -927,6 +929,7 @@ class CaseStmt final
927929
CaseLabelItem> {
928930
friend TrailingObjects;
929931

932+
Stmt *ParentStmt = nullptr;
930933
SourceLoc UnknownAttrLoc;
931934
SourceLoc ItemIntroducerLoc;
932935
SourceLoc ItemTerminatorLoc;
@@ -954,6 +957,14 @@ class CaseStmt final
954957

955958
CaseParentKind getParentKind() const { return ParentKind; }
956959

960+
Stmt *getParentStmt() const { return ParentStmt; }
961+
void setParentStmt(Stmt *S) {
962+
assert(S && "Parent statement must be SwitchStmt or DoCatchStmt");
963+
assert((ParentKind == CaseParentKind::Switch && isa<SwitchStmt>(S)) ||
964+
(ParentKind == CaseParentKind::DoCatch && isa<DoCatchStmt>(S)));
965+
ParentStmt = S;
966+
}
967+
957968
ArrayRef<CaseLabelItem> getCaseLabelItems() const {
958969
return {getTrailingObjects<CaseLabelItem>(), Bits.CaseStmt.NumPatterns};
959970
}
@@ -1161,6 +1172,8 @@ class DoCatchStmt final
11611172
Bits.DoCatchStmt.NumCatches = catches.size();
11621173
std::uninitialized_copy(catches.begin(), catches.end(),
11631174
getTrailingObjects<CaseStmt *>());
1175+
for (auto *catchStmt : getCatches())
1176+
catchStmt->setParentStmt(this);
11641177
}
11651178

11661179
public:

lib/AST/Stmt.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
479479

480480
std::uninitialized_copy(Cases.begin(), Cases.end(),
481481
theSwitch->getTrailingObjects<ASTNode>());
482+
for (auto *caseStmt : theSwitch->getCases())
483+
caseStmt->setParentStmt(theSwitch);
484+
482485
return theSwitch;
483486
}
484487

0 commit comments

Comments
 (0)