Skip to content

Commit a35f48c

Browse files
committed
walk
1 parent e8c81fb commit a35f48c

File tree

2 files changed

+73
-39
lines changed

2 files changed

+73
-39
lines changed

include/swift/Sema/SolutionApplicationTarget.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ class SolutionApplicationTarget {
869869
}
870870

871871
/// Walk the contents of the application target.
872-
SolutionApplicationTarget walk(ASTWalker &walker) const;
872+
Optional<SolutionApplicationTarget> walk(ASTWalker &walker) const;
873873
};
874874

875875
}

lib/Sema/SolutionApplicationTarget.cpp

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -281,50 +281,84 @@ bool SolutionApplicationTarget::contextualTypeIsOnlyAHint() const {
281281
llvm_unreachable("invalid contextual type");
282282
}
283283

284-
SolutionApplicationTarget SolutionApplicationTarget::walk(ASTWalker &walker) {
284+
Optional<SolutionApplicationTarget>
285+
SolutionApplicationTarget::walk(ASTWalker &walker) const {
286+
SolutionApplicationTarget result = *this;
285287
switch (kind) {
286288
case Kind::expression: {
287-
SolutionApplicationTarget result = *this;
288-
result.setExpr(getAsExpr()->walk(walker));
289-
return result;
289+
if (isForInitialization()) {
290+
if (auto *newPattern = getInitializationPattern()->walk(walker)) {
291+
result.setPattern(newPattern);
292+
} else {
293+
return None;
294+
}
295+
}
296+
if (auto *newExpr = getAsExpr()->walk(walker)) {
297+
result.setExpr(newExpr);
298+
} else {
299+
return None;
300+
}
301+
break;
290302
}
291-
292-
case Kind::closure:
293-
return *this;
294-
295-
case Kind::function:
296-
return SolutionApplicationTarget(
297-
*getAsFunction(),
298-
cast_or_null<BraceStmt>(getFunctionBody()->walk(walker)));
299-
300-
case Kind::stmtCondition:
301-
for (auto &condElement : stmtCondition.stmtCondition) {
302-
condElement = *condElement.walk(walker);
303+
case Kind::closure: {
304+
if (auto *newClosure = closure.closure->walk(walker)) {
305+
result.closure.closure = cast<ClosureExpr>(newClosure);
306+
} else {
307+
return None;
308+
}
309+
break;
310+
}
311+
case Kind::function: {
312+
if (auto *newBody = getFunctionBody()->walk(walker)) {
313+
result.function.body = cast<BraceStmt>(newBody);
314+
} else {
315+
return None;
303316
}
304-
return *this;
317+
break;
318+
}
319+
case Kind::stmtCondition: {
320+
for (auto &condElement : stmtCondition.stmtCondition)
321+
condElement = *condElement.walk(walker);
305322

306-
case Kind::caseLabelItem:
307-
if (auto newPattern =
308-
caseLabelItem.caseLabelItem->getPattern()->walk(walker)) {
309-
caseLabelItem.caseLabelItem->setPattern(
310-
newPattern, caseLabelItem.caseLabelItem->isPatternResolved());
323+
break;
324+
}
325+
case Kind::caseLabelItem: {
326+
auto *item = caseLabelItem.caseLabelItem;
327+
if (auto *newPattern = item->getPattern()->walk(walker)) {
328+
item->setPattern(newPattern, item->isPatternResolved());
329+
} else {
330+
return None;
311331
}
312-
if (auto guardExpr = caseLabelItem.caseLabelItem->getGuardExpr()) {
313-
if (auto newGuardExpr = guardExpr->walk(walker))
314-
caseLabelItem.caseLabelItem->setGuardExpr(newGuardExpr);
332+
if (auto guardExpr = item->getGuardExpr()) {
333+
if (auto newGuardExpr = guardExpr->walk(walker)) {
334+
item->setGuardExpr(newGuardExpr);
335+
} else {
336+
return None;
337+
}
315338
}
316-
317-
return *this;
318-
319-
case Kind::patternBinding:
320-
return *this;
321-
322-
case Kind::uninitializedVar:
323-
return *this;
324-
325-
case Kind::forEachStmt:
326-
return *this;
339+
break;
327340
}
328-
329-
llvm_unreachable("invalid target kind");
341+
case Kind::patternBinding: {
342+
if (getAsPatternBinding()->walk(walker))
343+
return None;
344+
break;
345+
}
346+
case Kind::uninitializedVar: {
347+
if (auto *P = getAsUninitializedVar()->walk(walker)) {
348+
result.setPattern(P);
349+
} else {
350+
return None;
351+
}
352+
break;
353+
}
354+
case Kind::forEachStmt: {
355+
if (auto *newStmt = getAsForEachStmt()->walk(walker)) {
356+
result.forEachStmt.stmt = cast<ForEachStmt>(newStmt);
357+
} else {
358+
return None;
359+
}
360+
break;
361+
}
362+
}
363+
return result;
330364
}

0 commit comments

Comments
 (0)