Skip to content

Commit 2b5bcc1

Browse files
authored
Merge pull request #40544 from DougGregor/sr-15131-closure-effects
Extend "uses concurrency features" checks for closures currently being type checked
2 parents f043a48 + cc7904c commit 2b5bcc1

File tree

10 files changed

+206
-86
lines changed

10 files changed

+206
-86
lines changed

include/swift/AST/Expr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3615,6 +3615,9 @@ class AbstractClosureExpr : public DeclContext, public Expr {
36153615
/// \brief Return whether this closure is async when fully applied.
36163616
bool isBodyAsync() const;
36173617

3618+
/// Whether this closure is Sendable.
3619+
bool isSendable() const;
3620+
36183621
/// Whether this closure consists of a single expression.
36193622
bool hasSingleExpressionBody() const;
36203623

include/swift/AST/TypeCheckRequests.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3191,6 +3191,23 @@ class RenamedDeclRequest
31913191
bool isCached() const { return true; }
31923192
};
31933193

3194+
class ClosureEffectsRequest
3195+
: public SimpleRequest<ClosureEffectsRequest,
3196+
FunctionType::ExtInfo(ClosureExpr *),
3197+
RequestFlags::Cached> {
3198+
public:
3199+
using SimpleRequest::SimpleRequest;
3200+
3201+
private:
3202+
friend SimpleRequest;
3203+
3204+
FunctionType::ExtInfo evaluate(
3205+
Evaluator &evaluator, ClosureExpr *closure) const;
3206+
3207+
public:
3208+
bool isCached() const { return true; }
3209+
};
3210+
31943211
void simple_display(llvm::raw_ostream &out, Type value);
31953212
void simple_display(llvm::raw_ostream &out, const TypeRepr *TyR);
31963213
void simple_display(llvm::raw_ostream &out, ImplicitMemberAction action);

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,6 @@ SWIFT_REQUEST(TypeChecker, GetImplicitSendableRequest,
363363
SWIFT_REQUEST(TypeChecker, RenamedDeclRequest,
364364
ValueDecl *(const ValueDecl *),
365365
Cached, NoLocationInfo)
366+
SWIFT_REQUEST(TypeChecker, ClosureEffectsRequest,
367+
FunctionType::ExtInfo(ClosureExpr *),
368+
Cached, NoLocationInfo)

include/swift/Sema/ConstraintSystem.h

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,13 @@ enum class SolutionApplicationToFunctionResult {
22172217
Delay,
22182218
};
22192219

2220+
/// Retrieve the closure type from the constraint system.
2221+
struct GetClosureType {
2222+
ConstraintSystem &cs;
2223+
2224+
Type operator()(const AbstractClosureExpr *expr) const;
2225+
};
2226+
22202227
/// Describes a system of constraints on type variables, the
22212228
/// solution of which assigns concrete types to each of the type variables.
22222229
/// Constraint systems are typically generated given an (untyped) expression.
@@ -2443,9 +2450,6 @@ class ConstraintSystem {
24432450
llvm::MapVector<AnyFunctionRef, AppliedBuilderTransform>
24442451
resultBuilderTransformed;
24452452

2446-
/// Cache of the effects any closures visited.
2447-
llvm::SmallDenseMap<ClosureExpr *, FunctionType::ExtInfo, 4> closureEffectsCache;
2448-
24492453
/// A mapping from the constraint locators for references to various
24502454
/// names (e.g., member references, normal name references, possible
24512455
/// constructions) to the argument lists for the call to that locator.
@@ -3099,9 +3103,16 @@ class ConstraintSystem {
30993103
}
31003104

31013105
FunctionType *getClosureType(const ClosureExpr *closure) const {
3106+
auto result = getClosureTypeIfAvailable(closure);
3107+
assert(result);
3108+
return result;
3109+
}
3110+
3111+
FunctionType *getClosureTypeIfAvailable(const ClosureExpr *closure) const {
31023112
auto result = ClosureTypes.find(closure);
3103-
assert(result != ClosureTypes.end());
3104-
return result->second;
3113+
if (result != ClosureTypes.end())
3114+
return result->second;
3115+
return nullptr;
31053116
}
31063117

31073118
TypeBase* getFavoredType(Expr *E) {
@@ -4119,6 +4130,12 @@ class ConstraintSystem {
41194130
ConstraintLocatorBuilder locator,
41204131
const OpenedTypeMap &replacements);
41214132

4133+
/// Wrapper over swift::adjustFunctionTypeForConcurrency that passes along
4134+
/// the appropriate closure-type extraction function.
4135+
AnyFunctionType *adjustFunctionTypeForConcurrency(
4136+
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
4137+
unsigned numApplies, bool isMainDispatchQueue);
4138+
41224139
/// Retrieve the type of a reference to the given value declaration.
41234140
///
41244141
/// For references to polymorphic function types, this routine "opens up"
@@ -4167,10 +4184,15 @@ class ConstraintSystem {
41674184
///
41684185
/// \param getType Optional callback to extract a type for given declaration.
41694186
static Type
4170-
getUnopenedTypeOfReference(VarDecl *value, Type baseType, DeclContext *UseDC,
4171-
llvm::function_ref<Type(VarDecl *)> getType,
4172-
ConstraintLocator *memberLocator = nullptr,
4173-
bool wantInterfaceType = false);
4187+
getUnopenedTypeOfReference(
4188+
VarDecl *value, Type baseType, DeclContext *UseDC,
4189+
llvm::function_ref<Type(VarDecl *)> getType,
4190+
ConstraintLocator *memberLocator = nullptr,
4191+
bool wantInterfaceType = false,
4192+
llvm::function_ref<Type(const AbstractClosureExpr *)> getClosureType =
4193+
[](const AbstractClosureExpr *) {
4194+
return Type();
4195+
});
41744196

41754197
/// Retrieve the type of a reference to the given value declaration,
41764198
/// as a member with a base of the given type.

lib/AST/Expr.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,15 +1708,49 @@ Type AbstractClosureExpr::getResultType(
17081708
}
17091709

17101710
bool AbstractClosureExpr::isBodyThrowing() const {
1711-
if (getType()->hasError())
1711+
if (!getType() || getType()->hasError()) {
1712+
// Scan the closure body to infer effects.
1713+
if (auto closure = dyn_cast<ClosureExpr>(this)) {
1714+
return evaluateOrDefault(
1715+
getASTContext().evaluator,
1716+
ClosureEffectsRequest{const_cast<ClosureExpr *>(closure)},
1717+
FunctionType::ExtInfo()).isThrowing();
1718+
}
1719+
17121720
return false;
1721+
}
17131722

17141723
return getType()->castTo<FunctionType>()->getExtInfo().isThrowing();
17151724
}
17161725

1726+
bool AbstractClosureExpr::isSendable() const {
1727+
if (!getType() || getType()->hasError()) {
1728+
// Scan the closure body to infer effects.
1729+
if (auto closure = dyn_cast<ClosureExpr>(this)) {
1730+
return evaluateOrDefault(
1731+
getASTContext().evaluator,
1732+
ClosureEffectsRequest{const_cast<ClosureExpr *>(closure)},
1733+
FunctionType::ExtInfo()).isSendable();
1734+
}
1735+
1736+
return false;
1737+
}
1738+
1739+
return getType()->castTo<FunctionType>()->getExtInfo().isSendable();
1740+
}
1741+
17171742
bool AbstractClosureExpr::isBodyAsync() const {
1718-
if (getType()->hasError())
1743+
if (!getType() || getType()->hasError()) {
1744+
// Scan the closure body to infer effects.
1745+
if (auto closure = dyn_cast<ClosureExpr>(this)) {
1746+
return evaluateOrDefault(
1747+
getASTContext().evaluator,
1748+
ClosureEffectsRequest{const_cast<ClosureExpr *>(closure)},
1749+
FunctionType::ExtInfo()).isAsync();
1750+
}
1751+
17191752
return false;
1753+
}
17201754

17211755
return getType()->castTo<FunctionType>()->getExtInfo().isAsync();
17221756
}

lib/Sema/CSSimplify.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2171,7 +2171,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
21712171

21722172
/// Whether to downgrade to a concurrency warning.
21732173
auto isConcurrencyWarning = [&] {
2174-
if (contextRequiresStrictConcurrencyChecking(DC))
2174+
if (contextRequiresStrictConcurrencyChecking(DC, GetClosureType{*this}))
21752175
return false;
21762176

21772177
switch (kind) {

lib/Sema/ConstraintSystem.cpp

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,21 @@ doesStorageProduceLValue(AbstractStorageDecl *storage, Type baseType,
10981098
!storage->isSetterMutating());
10991099
}
11001100

1101+
Type GetClosureType::operator()(const AbstractClosureExpr *expr) const {
1102+
if (auto closure = dyn_cast<ClosureExpr>(expr)) {
1103+
// Look through type bindings, if we have them.
1104+
auto mutableClosure = const_cast<ClosureExpr *>(closure);
1105+
if (cs.hasType(mutableClosure)) {
1106+
return cs.getFixedTypeRecursive(
1107+
cs.getType(mutableClosure), /*wantRValue=*/true);
1108+
}
1109+
1110+
return cs.getClosureTypeIfAvailable(closure);
1111+
}
1112+
1113+
return Type();
1114+
}
1115+
11011116
Type ConstraintSystem::getUnopenedTypeOfReference(
11021117
VarDecl *value, Type baseType, DeclContext *UseDC,
11031118
ConstraintLocator *memberLocator, bool wantInterfaceType) {
@@ -1113,18 +1128,20 @@ Type ConstraintSystem::getUnopenedTypeOfReference(
11131128

11141129
return wantInterfaceType ? var->getInterfaceType() : var->getType();
11151130
},
1116-
memberLocator, wantInterfaceType);
1131+
memberLocator, wantInterfaceType, GetClosureType{*this});
11171132
}
11181133

11191134
Type ConstraintSystem::getUnopenedTypeOfReference(
11201135
VarDecl *value, Type baseType, DeclContext *UseDC,
11211136
llvm::function_ref<Type(VarDecl *)> getType,
1122-
ConstraintLocator *memberLocator, bool wantInterfaceType) {
1137+
ConstraintLocator *memberLocator, bool wantInterfaceType,
1138+
llvm::function_ref<Type(const AbstractClosureExpr *)> getClosureType) {
11231139
Type requestedType =
11241140
getType(value)->getWithoutSpecifierType()->getReferenceStorageReferent();
11251141

11261142
// Adjust the type for concurrency.
1127-
requestedType = adjustVarTypeForConcurrency(requestedType, value, UseDC);
1143+
requestedType = adjustVarTypeForConcurrency(
1144+
requestedType, value, UseDC, getClosureType);
11281145

11291146
// If we're dealing with contextual types, and we referenced this type from
11301147
// a different context, map the type.
@@ -1298,6 +1315,13 @@ static bool isRequirementOrWitness(const ConstraintLocatorBuilder &locator) {
12981315
return false;
12991316
}
13001317

1318+
AnyFunctionType *ConstraintSystem::adjustFunctionTypeForConcurrency(
1319+
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
1320+
unsigned numApplies, bool isMainDispatchQueue) {
1321+
return swift::adjustFunctionTypeForConcurrency(
1322+
fnType, decl, dc, numApplies, isMainDispatchQueue, GetClosureType{*this});
1323+
}
1324+
13011325
std::pair<Type, Type>
13021326
ConstraintSystem::getTypeOfReference(ValueDecl *value,
13031327
FunctionRefKind functionRefKind,
@@ -1314,7 +1338,8 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
13141338
->castTo<AnyFunctionType>();
13151339
if (!isRequirementOrWitness(locator)) {
13161340
unsigned numApplies = getNumApplications(value, false, functionRefKind);
1317-
funcType = adjustFunctionTypeForConcurrency(funcType, func, useDC, numApplies, false);
1341+
funcType = adjustFunctionTypeForConcurrency(
1342+
funcType, func, useDC, numApplies, false);
13181343
}
13191344
auto openedType = openFunctionType(
13201345
funcType, locator, replacements, func->getDeclContext());
@@ -2085,7 +2110,8 @@ Type ConstraintSystem::getEffectiveOverloadType(ConstraintLocator *locator,
20852110
} else if (type->hasDynamicSelfType()) {
20862111
type = withDynamicSelfResultReplaced(type, /*uncurryLevel=*/0);
20872112
}
2088-
type = adjustVarTypeForConcurrency(type, var, useDC);
2113+
type = adjustVarTypeForConcurrency(
2114+
type, var, useDC, GetClosureType{*this});
20892115
} else if (isa<AbstractFunctionDecl>(decl) || isa<EnumElementDecl>(decl)) {
20902116
if (decl->isInstanceMember() &&
20912117
(!overload.getBaseType() ||
@@ -2339,20 +2365,18 @@ isInvalidPartialApplication(ConstraintSystem &cs,
23392365
return {true, level};
23402366
}
23412367

2342-
/// Walk a closure AST to determine its effects.
2343-
///
2344-
/// \returns a function's extended info describing the effects, as
2345-
/// determined syntactically.
23462368
FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
2347-
auto known = closureEffectsCache.find(expr);
2348-
if (known != closureEffectsCache.end())
2349-
return known->second;
2369+
return evaluateOrDefault(
2370+
getASTContext().evaluator, ClosureEffectsRequest{expr},
2371+
FunctionType::ExtInfo());
2372+
}
23502373

2374+
FunctionType::ExtInfo ClosureEffectsRequest::evaluate(
2375+
Evaluator &evaluator, ClosureExpr *expr) const {
23512376
// A walker that looks for 'try' and 'throw' expressions
23522377
// that aren't nested within closures, nested declarations,
23532378
// or exhaustive catches.
23542379
class FindInnerThrows : public ASTWalker {
2355-
ConstraintSystem &CS;
23562380
DeclContext *DC;
23572381
bool FoundThrow = false;
23582382

@@ -2449,7 +2473,7 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
24492473
// Okay, now it should be safe to coerce the pattern.
24502474
// Pull the top-level pattern back out.
24512475
pattern = LabelItem.getPattern();
2452-
Type exnType = CS.getASTContext().getErrorDecl()->getDeclaredInterfaceType();
2476+
Type exnType = DC->getASTContext().getErrorDecl()->getDeclaredInterfaceType();
24532477

24542478
if (!exnType)
24552479
return false;
@@ -2501,8 +2525,8 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
25012525
}
25022526

25032527
public:
2504-
FindInnerThrows(ConstraintSystem &cs, DeclContext *dc)
2505-
: CS(cs), DC(dc) {}
2528+
FindInnerThrows(DeclContext *dc)
2529+
: DC(dc) {}
25062530

25072531
bool foundThrow() { return FoundThrow; }
25082532
};
@@ -2525,23 +2549,27 @@ FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
25252549
if (!body)
25262550
return ASTExtInfoBuilder().withConcurrent(concurrent).build();
25272551

2528-
auto throwFinder = FindInnerThrows(*this, expr);
2552+
auto throwFinder = FindInnerThrows(expr);
25292553
body->walk(throwFinder);
2530-
auto result = ASTExtInfoBuilder()
2531-
.withThrows(throwFinder.foundThrow())
2532-
.withAsync(bool(findAsyncNode(expr)))
2533-
.withConcurrent(concurrent)
2534-
.build();
2535-
closureEffectsCache[expr] = result;
2536-
return result;
2554+
return ASTExtInfoBuilder()
2555+
.withThrows(throwFinder.foundThrow())
2556+
.withAsync(bool(findAsyncNode(expr)))
2557+
.withConcurrent(concurrent)
2558+
.build();
25372559
}
25382560

25392561
bool ConstraintSystem::isAsynchronousContext(DeclContext *dc) {
25402562
if (auto func = dyn_cast<AbstractFunctionDecl>(dc))
25412563
return func->isAsyncContext();
25422564

2543-
if (auto closure = dyn_cast<ClosureExpr>(dc))
2544-
return closureEffects(closure).isAsync();
2565+
if (auto abstractClosure = dyn_cast<AbstractClosureExpr>(dc)) {
2566+
if (Type type = GetClosureType{*this}(abstractClosure)) {
2567+
if (auto fnType = type->getAs<AnyFunctionType>())
2568+
return fnType->isAsync();
2569+
}
2570+
2571+
return abstractClosure->isBodyAsync();
2572+
}
25452573

25462574
return false;
25472575
}

0 commit comments

Comments
 (0)