@@ -39,6 +39,8 @@ class Pattern;
39
39
class PatternBindingDecl ;
40
40
class VarDecl ;
41
41
class CaseStmt ;
42
+ class DoCatchStmt ;
43
+ class SwitchStmt ;
42
44
43
45
enum class StmtKind {
44
46
#define STMT (ID, PARENT ) ID,
@@ -927,6 +929,7 @@ class CaseStmt final
927
929
CaseLabelItem> {
928
930
friend TrailingObjects;
929
931
932
+ Stmt *ParentStmt = nullptr ;
930
933
SourceLoc UnknownAttrLoc;
931
934
SourceLoc ItemIntroducerLoc;
932
935
SourceLoc ItemTerminatorLoc;
@@ -954,6 +957,14 @@ class CaseStmt final
954
957
955
958
CaseParentKind getParentKind () const { return ParentKind; }
956
959
960
+ Stmt *getParentStmt () const { return ParentStmt; }
961
+ void setParentStmt (Stmt *S) {
962
+ assert (S && " Parent statement must be SwitchStmt or DoCatchStmt" );
963
+ assert ((ParentKind == CaseParentKind::Switch && isa<SwitchStmt>(S)) ||
964
+ (ParentKind == CaseParentKind::DoCatch && isa<DoCatchStmt>(S)));
965
+ ParentStmt = S;
966
+ }
967
+
957
968
ArrayRef<CaseLabelItem> getCaseLabelItems () const {
958
969
return {getTrailingObjects<CaseLabelItem>(), Bits.CaseStmt .NumPatterns };
959
970
}
@@ -1161,6 +1172,8 @@ class DoCatchStmt final
1161
1172
Bits.DoCatchStmt .NumCatches = catches.size ();
1162
1173
std::uninitialized_copy (catches.begin (), catches.end (),
1163
1174
getTrailingObjects<CaseStmt *>());
1175
+ for (auto *catchStmt : getCatches ())
1176
+ catchStmt->setParentStmt (this );
1164
1177
}
1165
1178
1166
1179
public:
0 commit comments