Skip to content

Commit 564b4fc

Browse files
committed
[silgenpattern] Fix some stack-use-after-free errors caused by iterating over an Optional<ArrayRef<T>>.
Specifically the bad pattern was: ``` for (auto *vd : *caseStmt->getCaseBodyVariables()) { ... } ``` The problem is that the optional is not lifetime extended over the for loop. To work around this, I changed the API of CaseStmt's getCaseBodyVariable methods to never return the inner Optional<MutableArrayRef<T>>. Now we have the following 3 methods (ignoring const differences): 1. CaseStmt::hasCaseBodyVariables(). 2. CaseStmt::getCaseBodyVariables(). Asserts if the case body variable array was never specified. 3. CaseStmt::getCaseBodyVariablesOrEmptyArray(). Returns either the case body variables array or an empty array if we were never given any case body variable array. This should prevent anyone else in the future from hitting this type of bug. radar://49609717
1 parent dc7879d commit 564b4fc

File tree

6 files changed

+35
-39
lines changed

6 files changed

+35
-39
lines changed

include/swift/AST/Stmt.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,15 +1055,26 @@ class CaseStmt final
10551055
return UnknownAttrLoc.isValid();
10561056
}
10571057

1058-
Optional<ArrayRef<VarDecl *>> getCaseBodyVariables() const {
1059-
if (!CaseBodyVariables)
1060-
return None;
1058+
/// Return an ArrayRef containing the case body variables of this CaseStmt.
1059+
///
1060+
/// Asserts if case body variables was not explicitly initialized. In contexts
1061+
/// where one wants a non-asserting version, \see
1062+
/// getCaseBodyVariablesOrEmptyArray.
1063+
ArrayRef<VarDecl *> getCaseBodyVariables() const {
10611064
ArrayRef<VarDecl *> a = *CaseBodyVariables;
10621065
return a;
10631066
}
10641067

1065-
Optional<MutableArrayRef<VarDecl *>> getCaseBodyVariables() {
1066-
return CaseBodyVariables;
1068+
bool hasCaseBodyVariables() const { return CaseBodyVariables.hasValue(); }
1069+
1070+
/// Return an MutableArrayRef containing the case body variables of this
1071+
/// CaseStmt.
1072+
///
1073+
/// Asserts if case body variables was not explicitly initialized. In contexts
1074+
/// where one wants a non-asserting version, \see
1075+
/// getCaseBodyVariablesOrEmptyArray.
1076+
MutableArrayRef<VarDecl *> getCaseBodyVariables() {
1077+
return *CaseBodyVariables;
10671078
}
10681079

10691080
ArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() const {

lib/AST/ASTDumper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,13 +1622,13 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
16221622
if (S->hasUnknownAttr())
16231623
OS << " @unknown";
16241624

1625-
if (auto caseBodyVars = S->getCaseBodyVariables()) {
1625+
if (S->hasCaseBodyVariables()) {
16261626
OS << '\n';
16271627
OS.indent(Indent + 2);
16281628
PrintWithColorRAII(OS, ParenthesisColor) << '(';
16291629
PrintWithColorRAII(OS, StmtColor) << "case_body_variables";
16301630
OS << '\n';
1631-
for (auto *vd : *caseBodyVars) {
1631+
for (auto *vd : S->getCaseBodyVariables()) {
16321632
OS.indent(2);
16331633
// TODO: Printing a var decl does an Indent ... dump(vd) ... '\n'. We
16341634
// should see if we can factor this dumping so that the caller of

lib/AST/Decl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5094,7 +5094,7 @@ NullablePtr<VarDecl> VarDecl::getCorrespondingCaseBodyVariable() const {
50945094

50955095
// A var decl associated with a case stmt implies that the case stmt has body
50965096
// var decls. So we can access the optional value here without worry.
5097-
auto caseBodyVars = *caseStmt->getCaseBodyVariables();
5097+
auto caseBodyVars = caseStmt->getCaseBodyVariables();
50985098
auto result = llvm::find_if(caseBodyVars, [&](VarDecl *caseBodyVar) {
50995099
return caseBodyVar->getName() == name;
51005100
});

lib/AST/NameLookup.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,10 +2349,8 @@ void FindLocalVal::visitCaseStmt(CaseStmt *S) {
23492349
}
23502350

23512351
if (!inPatterns && !items.empty()) {
2352-
if (auto caseBodyVars = S->getCaseBodyVariables()) {
2353-
for (auto *vd : *caseBodyVars) {
2354-
checkValueDecl(vd, DeclVisibilityKind::LocalVariable);
2355-
}
2352+
for (auto *vd : S->getCaseBodyVariablesOrEmptyArray()) {
2353+
checkValueDecl(vd, DeclVisibilityKind::LocalVariable);
23562354
}
23572355
}
23582356
visit(S->getBody());

lib/SILGen/SILGenPattern.cpp

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2332,12 +2332,8 @@ void PatternMatchEmission::initSharedCaseBlockDest(CaseStmt *caseBlock,
23322332
auto *block = SGF.createBasicBlock();
23332333
result.first->second.first = block;
23342334

2335-
// Add args for any pattern variables
2336-
auto caseBodyVars = caseBlock->getCaseBodyVariables();
2337-
if (!caseBodyVars)
2338-
return;
2339-
2340-
for (auto *vd : *caseBodyVars) {
2335+
// Add args for any pattern variables if we have any.
2336+
for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
23412337
if (!vd->hasName())
23422338
continue;
23432339

@@ -2365,14 +2361,10 @@ void PatternMatchEmission::emitAddressOnlyAllocations() {
23652361
for (auto &entry : SharedCases) {
23662362
CaseStmt *caseBlock = entry.first;
23672363

2368-
auto caseBodyVars = caseBlock->getCaseBodyVariables();
2369-
if (!caseBodyVars)
2370-
continue;
2371-
23722364
// If we have a shared case with bound decls, setup the arguments for the
23732365
// shared block by emitting the temporary allocation used for the arguments
23742366
// of the shared block.
2375-
for (auto *vd : *caseBodyVars) {
2367+
for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
23762368
if (!vd->hasName())
23772369
continue;
23782370

@@ -2436,8 +2428,7 @@ void PatternMatchEmission::emitSharedCaseBlocks() {
24362428
assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth);
24372429
SWIFT_DEFER { assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth); };
24382430

2439-
auto caseBodyVars = caseBlock->getCaseBodyVariables();
2440-
if (!caseBodyVars) {
2431+
if (!caseBlock->hasCaseBodyVariables()) {
24412432
emitCaseBody(caseBlock);
24422433
continue;
24432434
}
@@ -2448,7 +2439,7 @@ void PatternMatchEmission::emitSharedCaseBlocks() {
24482439
// args needing Cleanup will get that as well.
24492440
Scope scope(SGF.Cleanups, CleanupLocation(caseBlock));
24502441
unsigned argIndex = 0;
2451-
for (auto *vd : *caseBodyVars) {
2442+
for (auto *vd : caseBlock->getCaseBodyVariables()) {
24522443
if (!vd->hasName())
24532444
continue;
24542445

@@ -2606,14 +2597,14 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
26062597
if (!row.hasFallthroughTo() && caseBlock->getCaseLabelItems().size() == 1) {
26072598
// If we have case body vars, set them up to point at the matching var
26082599
// decls.
2609-
if (auto caseBodyVars = caseBlock->getCaseBodyVariables()) {
2600+
if (caseBlock->hasCaseBodyVariables()) {
26102601
// Since we know that we only have one case label item, grab its pattern
26112602
// vars and use that to update expected with the right SILValue.
26122603
//
26132604
// TODO: Do we need a copy here?
26142605
SmallVector<VarDecl *, 4> patternVars;
26152606
row.getCasePattern()->collectVariables(patternVars);
2616-
for (auto *expected : *caseBodyVars) {
2607+
for (auto *expected : caseBlock->getCaseBodyVariables()) {
26172608
if (!expected->hasName())
26182609
continue;
26192610
for (auto *vd : patternVars) {
@@ -2640,8 +2631,7 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
26402631

26412632
// If we do not have any bound decls, we do not need to setup any
26422633
// variables. Just jump to the shared destination.
2643-
auto caseBodyVars = caseBlock->getCaseBodyVariables();
2644-
if (!caseBodyVars) {
2634+
if (!caseBlock->hasCaseBodyVariables()) {
26452635
// Don't emit anything yet, we emit it at the cleanup level of the switch
26462636
// statement.
26472637
JumpDest sharedDest = emission.getSharedCaseBlockDest(caseBlock);
@@ -2658,7 +2648,7 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
26582648
SILModule &M = SGF.F.getModule();
26592649
SmallVector<VarDecl *, 4> patternVars;
26602650
row.getCasePattern()->collectVariables(patternVars);
2661-
for (auto *expected : *caseBodyVars) {
2651+
for (auto *expected : caseBlock->getCaseBodyVariables()) {
26622652
if (!expected->hasName())
26632653
continue;
26642654
for (auto *var : patternVars) {
@@ -2845,8 +2835,7 @@ void SILGenFunction::emitSwitchFallthrough(FallthroughStmt *S) {
28452835

28462836
// If our destination case doesn't have any bound decls, there is no rebinding
28472837
// to do. Just jump to the shared dest.
2848-
auto destCaseBodyVars = destCaseStmt->getCaseBodyVariables();
2849-
if (!destCaseBodyVars) {
2838+
if (!destCaseStmt->hasCaseBodyVariables()) {
28502839
Cleanups.emitBranchAndCleanups(sharedDest, S);
28512840
return;
28522841
}
@@ -2856,13 +2845,13 @@ void SILGenFunction::emitSwitchFallthrough(FallthroughStmt *S) {
28562845
SmallVector<SILValue, 4> args;
28572846
CaseStmt *fallthroughSourceStmt = S->getFallthroughSource();
28582847

2859-
for (auto *expected : *destCaseBodyVars) {
2848+
for (auto *expected : destCaseStmt->getCaseBodyVariables()) {
28602849
if (!expected->hasName())
28612850
continue;
28622851

28632852
// The type checker enforces that if our destination case has variables then
28642853
// our fallthrough source must as well.
2865-
for (auto *var : *fallthroughSourceStmt->getCaseBodyVariables()) {
2854+
for (auto *var : fallthroughSourceStmt->getCaseBodyVariables()) {
28662855
if (!var->hasName() || var->getName() != expected->getName()) {
28672856
continue;
28682857
}

lib/Sema/MiscDiagnostics.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,10 +2414,8 @@ class VarDeclUsageChecker : public ASTWalker {
24142414

24152415
// Make sure that we setup our case body variables.
24162416
if (auto *caseStmt = dyn_cast<CaseStmt>(S)) {
2417-
if (auto caseBoundDecls = caseStmt->getCaseBodyVariables()) {
2418-
for (auto *vd : *caseBoundDecls) {
2419-
VarDecls[vd] |= 0;
2420-
}
2417+
for (auto *vd : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
2418+
VarDecls[vd] |= 0;
24212419
}
24222420
}
24232421

0 commit comments

Comments
 (0)