Skip to content

Commit 3338b03

Browse files
authored
[CS] Connect closure to referenced vars (#31304)
[CS] Connect closure to referenced vars
2 parents 428d58e + 2070b2c commit 3338b03

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

lib/Sema/CSGen.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2740,16 +2740,19 @@ namespace {
27402740
auto *locator = CS.getConstraintLocator(closure);
27412741
auto closureType = CS.createTypeVariable(locator, TVO_CanBindToNoEscape);
27422742

2743-
// Collect any references to closure parameters whose types involve type
2744-
// variables from the closure, because there will be a dependency on
2745-
// those type variables once we have generated constraints for the
2746-
// closure body.
2747-
struct CollectParameterRefs : public ASTWalker {
2743+
// Collect any variable references whose types involve type variables,
2744+
// because there will be a dependency on those type variables once we have
2745+
// generated constraints for the closure body. This includes references
2746+
// to other closure params such as in `{ x in { x }}` where the inner
2747+
// closure is dependent on the outer closure's param type, as well as
2748+
// cases like `for i in x where bar({ i })` where there's a dependency on
2749+
// the type variable for the pattern `i`.
2750+
struct CollectVarRefs : public ASTWalker {
27482751
ConstraintSystem &cs;
2749-
llvm::SmallVector<TypeVariableType *, 4> paramRefs;
2752+
llvm::SmallVector<TypeVariableType *, 4> varRefs;
27502753
bool hasErrorExprs = false;
27512754

2752-
CollectParameterRefs(ConstraintSystem &cs) : cs(cs) { }
2755+
CollectVarRefs(ConstraintSystem &cs) : cs(cs) { }
27532756

27542757
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
27552758
// If there are any error expressions in this closure
@@ -2759,21 +2762,21 @@ namespace {
27592762
return {false, nullptr};
27602763
}
27612764

2762-
// Retrieve type variables from references to parameter declarations.
2765+
// Retrieve type variables from references to var decls.
27632766
if (auto *declRef = dyn_cast<DeclRefExpr>(expr)) {
2764-
if (auto *paramDecl = dyn_cast<ParamDecl>(declRef->getDecl())) {
2765-
if (Type paramType = cs.getTypeIfAvailable(paramDecl)) {
2766-
paramType->getTypeVariables(paramRefs);
2767+
if (auto *varDecl = dyn_cast<VarDecl>(declRef->getDecl())) {
2768+
if (auto varType = cs.getTypeIfAvailable(varDecl)) {
2769+
varType->getTypeVariables(varRefs);
27672770
}
27682771
}
27692772
}
27702773

27712774
return { true, expr };
27722775
}
2773-
} collectParameterRefs(CS);
2774-
closure->walk(collectParameterRefs);
2776+
} collectVarRefs(CS);
2777+
closure->walk(collectVarRefs);
27752778

2776-
if (collectParameterRefs.hasErrorExprs)
2779+
if (collectVarRefs.hasErrorExprs)
27772780
return Type();
27782781

27792782
auto inferredType = inferClosureType(closure);
@@ -2783,7 +2786,7 @@ namespace {
27832786
CS.addUnsolvedConstraint(
27842787
Constraint::create(CS, ConstraintKind::DefaultClosureType,
27852788
closureType, inferredType, locator,
2786-
collectParameterRefs.paramRefs));
2789+
collectVarRefs.varRefs));
27872790

27882791
CS.setClosureType(closure, inferredType);
27892792
return closureType;

test/stmt/foreach.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,12 @@ func sr_12398(arr1: [Int], arr2: [(a: Int, b: String)]) {
229229
for (x, y, _) in arr2 {}
230230
// expected-error@-1 {{pattern cannot match values of type '(a: Int, b: String)'}}
231231
}
232+
233+
// rdar://62339835
234+
func testForEachWhereWithClosure(_ x: [Int]) {
235+
func foo<T>(_ fn: () -> T) -> Bool { true }
236+
237+
for i in x where foo({ i }) {}
238+
for i in x where foo({ i.byteSwapped == 5 }) {}
239+
for i in x where x.contains(where: { $0.byteSwapped == i }) {}
240+
}

0 commit comments

Comments
 (0)