Skip to content

[silgenpattern] Fix two ASAN errors. #23801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1055,15 +1055,26 @@ class CaseStmt final
return UnknownAttrLoc.isValid();
}

Optional<ArrayRef<VarDecl *>> getCaseBodyVariables() const {
if (!CaseBodyVariables)
return None;
/// Return an ArrayRef containing the case body variables of this CaseStmt.
///
/// Asserts if case body variables was not explicitly initialized. In contexts
/// where one wants a non-asserting version, \see
/// getCaseBodyVariablesOrEmptyArray.
ArrayRef<VarDecl *> getCaseBodyVariables() const {
ArrayRef<VarDecl *> a = *CaseBodyVariables;
return a;
}

Optional<MutableArrayRef<VarDecl *>> getCaseBodyVariables() {
return CaseBodyVariables;
bool hasCaseBodyVariables() const { return CaseBodyVariables.hasValue(); }

/// Return an MutableArrayRef containing the case body variables of this
/// CaseStmt.
///
/// Asserts if case body variables was not explicitly initialized. In contexts
/// where one wants a non-asserting version, \see
/// getCaseBodyVariablesOrEmptyArray.
MutableArrayRef<VarDecl *> getCaseBodyVariables() {
return *CaseBodyVariables;
}

ArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() const {
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1622,13 +1622,13 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
if (S->hasUnknownAttr())
OS << " @unknown";

if (auto caseBodyVars = S->getCaseBodyVariables()) {
if (S->hasCaseBodyVariables()) {
OS << '\n';
OS.indent(Indent + 2);
PrintWithColorRAII(OS, ParenthesisColor) << '(';
PrintWithColorRAII(OS, StmtColor) << "case_body_variables";
OS << '\n';
for (auto *vd : *caseBodyVars) {
for (auto *vd : S->getCaseBodyVariables()) {
OS.indent(2);
// TODO: Printing a var decl does an Indent ... dump(vd) ... '\n'. We
// should see if we can factor this dumping so that the caller of
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5094,7 +5094,7 @@ NullablePtr<VarDecl> VarDecl::getCorrespondingCaseBodyVariable() const {

// A var decl associated with a case stmt implies that the case stmt has body
// var decls. So we can access the optional value here without worry.
auto caseBodyVars = *caseStmt->getCaseBodyVariables();
auto caseBodyVars = caseStmt->getCaseBodyVariables();
auto result = llvm::find_if(caseBodyVars, [&](VarDecl *caseBodyVar) {
return caseBodyVar->getName() == name;
});
Expand Down
6 changes: 2 additions & 4 deletions lib/AST/NameLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2349,10 +2349,8 @@ void FindLocalVal::visitCaseStmt(CaseStmt *S) {
}

if (!inPatterns && !items.empty()) {
if (auto caseBodyVars = S->getCaseBodyVariables()) {
for (auto *vd : *caseBodyVars) {
checkValueDecl(vd, DeclVisibilityKind::LocalVariable);
}
for (auto *vd : S->getCaseBodyVariablesOrEmptyArray()) {
checkValueDecl(vd, DeclVisibilityKind::LocalVariable);
}
}
visit(S->getBody());
Expand Down
38 changes: 14 additions & 24 deletions lib/SILGen/SILGenPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2332,12 +2332,8 @@ void PatternMatchEmission::initSharedCaseBlockDest(CaseStmt *caseBlock,
auto *block = SGF.createBasicBlock();
result.first->second.first = block;

// Add args for any pattern variables
auto caseBodyVars = caseBlock->getCaseBodyVariables();
if (!caseBodyVars)
return;

for (auto *vd : *caseBodyVars) {
// Add args for any pattern variables if we have any.
for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
if (!vd->hasName())
continue;

Expand Down Expand Up @@ -2365,14 +2361,10 @@ void PatternMatchEmission::emitAddressOnlyAllocations() {
for (auto &entry : SharedCases) {
CaseStmt *caseBlock = entry.first;

auto caseBodyVars = caseBlock->getCaseBodyVariables();
if (!caseBodyVars)
continue;

// If we have a shared case with bound decls, setup the arguments for the
// shared block by emitting the temporary allocation used for the arguments
// of the shared block.
for (auto *vd : *caseBodyVars) {
for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
if (!vd->hasName())
continue;

Expand Down Expand Up @@ -2436,8 +2428,7 @@ void PatternMatchEmission::emitSharedCaseBlocks() {
assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth);
SWIFT_DEFER { assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth); };

auto caseBodyVars = caseBlock->getCaseBodyVariables();
if (!caseBodyVars) {
if (!caseBlock->hasCaseBodyVariables()) {
emitCaseBody(caseBlock);
continue;
}
Expand All @@ -2448,7 +2439,7 @@ void PatternMatchEmission::emitSharedCaseBlocks() {
// args needing Cleanup will get that as well.
Scope scope(SGF.Cleanups, CleanupLocation(caseBlock));
unsigned argIndex = 0;
for (auto *vd : *caseBodyVars) {
for (auto *vd : caseBlock->getCaseBodyVariables()) {
if (!vd->hasName())
continue;

Expand Down Expand Up @@ -2606,14 +2597,14 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
if (!row.hasFallthroughTo() && caseBlock->getCaseLabelItems().size() == 1) {
// If we have case body vars, set them up to point at the matching var
// decls.
if (auto caseBodyVars = caseBlock->getCaseBodyVariables()) {
if (caseBlock->hasCaseBodyVariables()) {
// Since we know that we only have one case label item, grab its pattern
// vars and use that to update expected with the right SILValue.
//
// TODO: Do we need a copy here?
SmallVector<VarDecl *, 4> patternVars;
row.getCasePattern()->collectVariables(patternVars);
for (auto *expected : *caseBodyVars) {
for (auto *expected : caseBlock->getCaseBodyVariables()) {
if (!expected->hasName())
continue;
for (auto *vd : patternVars) {
Expand All @@ -2622,7 +2613,8 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
}

// Ok, we found a match. Update the VarLocs for the case block.
SGF.VarLocs[expected] = SGF.VarLocs[vd];
auto v = SGF.VarLocs[vd];
SGF.VarLocs[expected] = v;
}
}
}
Expand All @@ -2639,8 +2631,7 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,

// If we do not have any bound decls, we do not need to setup any
// variables. Just jump to the shared destination.
auto caseBodyVars = caseBlock->getCaseBodyVariables();
if (!caseBodyVars) {
if (!caseBlock->hasCaseBodyVariables()) {
// Don't emit anything yet, we emit it at the cleanup level of the switch
// statement.
JumpDest sharedDest = emission.getSharedCaseBlockDest(caseBlock);
Expand All @@ -2657,7 +2648,7 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
SILModule &M = SGF.F.getModule();
SmallVector<VarDecl *, 4> patternVars;
row.getCasePattern()->collectVariables(patternVars);
for (auto *expected : *caseBodyVars) {
for (auto *expected : caseBlock->getCaseBodyVariables()) {
if (!expected->hasName())
continue;
for (auto *var : patternVars) {
Expand Down Expand Up @@ -2844,8 +2835,7 @@ void SILGenFunction::emitSwitchFallthrough(FallthroughStmt *S) {

// If our destination case doesn't have any bound decls, there is no rebinding
// to do. Just jump to the shared dest.
auto destCaseBodyVars = destCaseStmt->getCaseBodyVariables();
if (!destCaseBodyVars) {
if (!destCaseStmt->hasCaseBodyVariables()) {
Cleanups.emitBranchAndCleanups(sharedDest, S);
return;
}
Expand All @@ -2855,13 +2845,13 @@ void SILGenFunction::emitSwitchFallthrough(FallthroughStmt *S) {
SmallVector<SILValue, 4> args;
CaseStmt *fallthroughSourceStmt = S->getFallthroughSource();

for (auto *expected : *destCaseBodyVars) {
for (auto *expected : destCaseStmt->getCaseBodyVariables()) {
if (!expected->hasName())
continue;

// The type checker enforces that if our destination case has variables then
// our fallthrough source must as well.
for (auto *var : *fallthroughSourceStmt->getCaseBodyVariables()) {
for (auto *var : fallthroughSourceStmt->getCaseBodyVariables()) {
if (!var->hasName() || var->getName() != expected->getName()) {
continue;
}
Expand Down
6 changes: 2 additions & 4 deletions lib/Sema/MiscDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2414,10 +2414,8 @@ class VarDeclUsageChecker : public ASTWalker {

// Make sure that we setup our case body variables.
if (auto *caseStmt = dyn_cast<CaseStmt>(S)) {
if (auto caseBoundDecls = caseStmt->getCaseBodyVariables()) {
for (auto *vd : *caseBoundDecls) {
VarDecls[vd] |= 0;
}
for (auto *vd : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
VarDecls[vd] |= 0;
}
}

Expand Down
83 changes: 83 additions & 0 deletions test/SILGen/switch.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1500,3 +1500,86 @@ func nonTrivialLoadableFallthroughCallee2(_ e : MultipleNonTrivialCaseEnum) {
}
}

// Make sure that we do not crash while emitting this code.
//
// DISCUSSION: The original crash was due to us performing an assignment/lookup
// on the VarLocs DenseMap in the same statement. This was caught be an
// asanified compiler. This test is just to make sure we do not regress.
enum Storage {
case empty
case single(Int)
case pair(Int, Int)
case array([Int])

subscript(range: [Int]) -> Storage {
get {
return .empty
}
set {
switch self {
case .empty:
break
case .single(let index):
break
case .pair(let first, let second):
switch (range[0], range[1]) {
case (0, 0):
switch newValue {
case .empty:
break
case .single(let other):
break
case .pair(let otherFirst, let otherSecond):
break
case .array(let other):
break
}
break
case (0, 1):
switch newValue {
case .empty:
break
case .single(let other):
break
case .pair(let otherFirst, let otherSecond):
break
case .array(let other):
break
}
break
case (0, 2):
break
case (1, 2):
switch newValue {
case .empty:
break
case .single(let other):
break
case .pair(let otherFirst, let otherSecond):
break
case .array(let other):
self = .array([first] + other)
}
break
case (2, 2):
switch newValue {
case .empty:
break
case .single(let other):
break
case .pair(let otherFirst, let otherSecond):
break
case .array(let other):
self = .array([first, second] + other)
}
break
default:
let r = range
}
case .array(let indexes):
break
}
}
}
}