Skip to content

Commit 1b8f562

Browse files
committed
Adjust the referenced function type for @_unsafeSendable and @_unsafeMainActor.
When referencing a function that contains parameters with the hidden `@_unsafeSendable` or `@_unsafeMainActor` attributes, adjust the function type to make the types of those parameters `@Sendable` or `@MainActor`, respectively, based on both the context the expression: * `@Sendable` will be applied when we are in a context with strict concurrency checking. * `@MainActor` will be applied when we are in a context with strict concurrency checking *or* the function is being directly applied so that an argument is provided in the immediate expression. The second part of the rule of `@MainActor` reflects the fact that making the parameter `@MainActor` doesn't break existing code (because there is a conversion to add a global actor to a function value), but it does enable such code to synchronously use a `@MainActor`-qualified API. The main effect of this change is that, in a strict concurrency context, the type of referencing an unapplied function involving `@_unsafeSendable` or `@_unsafeMainActor` in a strict context will make those parameters `@Sendable` or `@MainActor`, which ensures that these constraints properly work with non-closure arguments. The former solution only applied to closure literals, which left some holes in Sendable checking. Fixes rdar://77753021.
1 parent 9db2638 commit 1b8f562

File tree

7 files changed

+280
-18
lines changed

7 files changed

+280
-18
lines changed

lib/Sema/CSApply.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5825,16 +5825,14 @@ ArgumentList *ExprRewriter::coerceCallArguments(
58255825
// for things like trailing closures and args to property wrapper params.
58265826
arg.setLabel(param.getLabel());
58275827

5828-
// Determine whether the parameter is unsafe Sendable or MainActor, and
5829-
// record it as such.
5830-
bool isUnsafeSendable = paramInfo.isUnsafeSendable(paramIdx);
5831-
bool isMainActor = paramInfo.isUnsafeMainActor(paramIdx) ||
5832-
(isUnsafeSendable && apply && isMainDispatchQueue(apply->getFn()));
5828+
// Determine whether the closure argument should be treated as being on
5829+
// the main actor, having implicit self capture, or inheriting actor
5830+
// context.
58335831
bool isImplicitSelfCapture = paramInfo.isImplicitSelfCapture(paramIdx);
58345832
bool inheritsActorContext = paramInfo.inheritsActorContext(paramIdx);
58355833
applyContextualClosureFlags(
5836-
argExpr, isUnsafeSendable && contextUsesConcurrencyFeatures(dc),
5837-
isMainActor, isImplicitSelfCapture, inheritsActorContext);
5834+
argExpr, false,
5835+
false, isImplicitSelfCapture, inheritsActorContext);
58385836

58395837
// If the types exactly match, this is easy.
58405838
auto paramType = param.getOldType();

lib/Sema/ConstraintSystem.cpp

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,24 @@ static unsigned getNumRemovedArgumentLabels(ValueDecl *decl,
12551255
llvm_unreachable("Unhandled FunctionRefKind in switch.");
12561256
}
12571257

1258+
/// Determine the number of applications
1259+
static unsigned getNumApplications(
1260+
ValueDecl *decl, bool hasAppliedSelf, FunctionRefKind functionRefKind) {
1261+
switch (functionRefKind) {
1262+
case FunctionRefKind::Unapplied:
1263+
case FunctionRefKind::Compound:
1264+
return 0 + hasAppliedSelf;
1265+
1266+
case FunctionRefKind::SingleApply:
1267+
return 1 + hasAppliedSelf;
1268+
1269+
case FunctionRefKind::DoubleApply:
1270+
return 2;
1271+
}
1272+
1273+
llvm_unreachable("Unhandled FunctionRefKind in switch.");
1274+
}
1275+
12581276
/// Replaces property wrapper types in the parameter list of the given function type
12591277
/// with the wrapped-value or projected-value types (depending on argument label).
12601278
static FunctionType *
@@ -1336,8 +1354,10 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
13361354

13371355
AnyFunctionType *funcType = func->getInterfaceType()
13381356
->castTo<AnyFunctionType>();
1339-
if (!isRequirementOrWitness(locator))
1340-
funcType = applyGlobalActorType(funcType, func, useDC);
1357+
if (!isRequirementOrWitness(locator)) {
1358+
unsigned numApplies = getNumApplications(value, false, functionRefKind);
1359+
funcType = applyGlobalActorType(funcType, func, useDC, numApplies, false);
1360+
}
13411361
auto openedType = openFunctionType(
13421362
funcType, locator, replacements, func->getDeclContext());
13431363

@@ -1366,8 +1386,12 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
13661386
auto numLabelsToRemove = getNumRemovedArgumentLabels(
13671387
funcDecl, /*isCurriedInstanceReference=*/false, functionRefKind);
13681388

1369-
if (!isRequirementOrWitness(locator))
1370-
funcType = applyGlobalActorType(funcType, funcDecl, useDC);
1389+
if (!isRequirementOrWitness(locator)) {
1390+
unsigned numApplies = getNumApplications(
1391+
funcDecl, false, functionRefKind);
1392+
funcType = applyGlobalActorType(
1393+
funcType, funcDecl, useDC, numApplies, false);
1394+
}
13711395

13721396
auto openedType = openFunctionType(funcType, locator, replacements,
13731397
funcDecl->getDeclContext())
@@ -1634,6 +1658,72 @@ Type constraints::getDynamicSelfReplacementType(
16341658
->getMetatypeInstanceType();
16351659
}
16361660

1661+
/// Determine whether the given name is that of a DispatchQueue operation that
1662+
/// takes a closure to be executed on the queue.
1663+
static bool isDispatchQueueOperationName(StringRef name) {
1664+
return llvm::StringSwitch<bool>(name)
1665+
.Case("sync", true)
1666+
.Case("async", true)
1667+
.Case("asyncAndWait", true)
1668+
.Case("asyncAfter", true)
1669+
.Case("concurrentPerform", true)
1670+
.Default(false);
1671+
}
1672+
1673+
/// Determine whether this locator refers to a member of "DispatchQueue.main",
1674+
/// which is a special dispatch queue that executes its work on the main actor.
1675+
static bool isMainDispatchQueueMember(ConstraintLocator *locator) {
1676+
if (!locator)
1677+
return false;
1678+
1679+
if (locator->getPath().size() != 1 ||
1680+
!locator->isLastElement<LocatorPathElt::Member>())
1681+
return false;
1682+
1683+
auto expr = locator->getAnchor().dyn_cast<Expr *>();
1684+
if (!expr)
1685+
return false;
1686+
1687+
auto outerUnresolvedDot = dyn_cast<UnresolvedDotExpr>(expr);
1688+
if (!outerUnresolvedDot)
1689+
return false;
1690+
1691+
1692+
if (!isDispatchQueueOperationName(
1693+
outerUnresolvedDot->getName().getBaseName().userFacingName()))
1694+
return false;
1695+
1696+
auto innerUnresolvedDot = dyn_cast<UnresolvedDotExpr>(
1697+
outerUnresolvedDot->getBase());
1698+
if (!innerUnresolvedDot)
1699+
return false;
1700+
1701+
if (innerUnresolvedDot->getName().getBaseName().userFacingName() != "main")
1702+
return false;
1703+
1704+
auto typeExpr = dyn_cast<TypeExpr>(innerUnresolvedDot->getBase());
1705+
if (!typeExpr)
1706+
return false;
1707+
1708+
auto typeRepr = typeExpr->getTypeRepr();
1709+
if (!typeRepr)
1710+
return false;
1711+
1712+
auto identTypeRepr = dyn_cast<IdentTypeRepr>(typeRepr);
1713+
if (!identTypeRepr)
1714+
return false;
1715+
1716+
auto components = identTypeRepr->getComponentRange();
1717+
if (components.empty())
1718+
return false;
1719+
1720+
if (components.back()->getNameRef().getBaseName().userFacingName() !=
1721+
"DispatchQueue")
1722+
return false;
1723+
1724+
return true;
1725+
}
1726+
16371727
std::pair<Type, Type>
16381728
ConstraintSystem::getTypeOfMemberReference(
16391729
Type baseTy, ValueDecl *value, DeclContext *useDC,
@@ -1711,8 +1801,13 @@ ConstraintSystem::getTypeOfMemberReference(
17111801
// This is the easy case.
17121802
funcType = value->getInterfaceType()->castTo<AnyFunctionType>();
17131803

1714-
if (!isRequirementOrWitness(locator))
1715-
funcType = applyGlobalActorType(funcType, value, useDC);
1804+
if (!isRequirementOrWitness(locator)) {
1805+
unsigned numApplies = getNumApplications(
1806+
value, hasAppliedSelf, functionRefKind);
1807+
funcType = applyGlobalActorType(
1808+
funcType, value, useDC, numApplies,
1809+
isMainDispatchQueueMember(locator));
1810+
}
17161811
} else {
17171812
// For a property, build a type (Self) -> PropType.
17181813
// For a subscript, build a type (Self) -> (Indices...) -> ElementType.

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4035,8 +4035,119 @@ NormalProtocolConformance *GetImplicitSendableRequest::evaluate(
40354035
return formConformance(nullptr);
40364036
}
40374037

4038+
/// Apply @Sendable and/or @MainActor to the given parameter type.
4039+
static Type applyUnsafeConcurrencyToParameterType(
4040+
Type type, bool sendable, bool mainActor) {
4041+
if (Type objectType = type->getOptionalObjectType()) {
4042+
return OptionalType::get(
4043+
applyUnsafeConcurrencyToParameterType(objectType, sendable, mainActor));
4044+
}
4045+
4046+
auto fnType = type->getAs<FunctionType>();
4047+
if (!fnType)
4048+
return type;
4049+
4050+
Type globalActor;
4051+
if (mainActor)
4052+
globalActor = type->getASTContext().getMainActorType();
4053+
4054+
return fnType->withExtInfo(fnType->getExtInfo()
4055+
.withConcurrent(sendable)
4056+
.withGlobalActor(globalActor));
4057+
}
4058+
4059+
/// Apply @_unsafeSendable and @_unsafeMainActor to the parameters of the
4060+
/// given function.
4061+
static AnyFunctionType *applyUnsafeConcurrencyToFunctionType(
4062+
AnyFunctionType *fnType, ValueDecl *funcOrEnum,
4063+
bool inConcurrencyContext, unsigned numApplies, bool isMainDispatchQueue) {
4064+
// Only functions can have @_unsafeSendable/@_unsafeMainActor parameters.
4065+
auto func = dyn_cast<AbstractFunctionDecl>(funcOrEnum);
4066+
if (!func)
4067+
return fnType;
4068+
4069+
AnyFunctionType *outerFnType = nullptr;
4070+
if (func->hasImplicitSelfDecl()) {
4071+
outerFnType = fnType;
4072+
fnType = outerFnType->getResult()->castTo<AnyFunctionType>();
4073+
4074+
if (numApplies > 0)
4075+
--numApplies;
4076+
}
4077+
4078+
SmallVector<AnyFunctionType::Param, 4> newTypeParams;
4079+
auto typeParams = fnType->getParams();
4080+
auto paramDecls = func->getParameters();
4081+
assert(typeParams.size() == paramDecls->size());
4082+
for (unsigned index : indices(typeParams)) {
4083+
auto param = typeParams[index];
4084+
auto paramDecl = (*paramDecls)[index];
4085+
4086+
// Determine whether the resulting parameter should be @Sendable or
4087+
// @MainActor. @Sendable occurs only in concurrency contents, while
4088+
// @MainActor occurs in concurrency contexts or those where we have an
4089+
// application.
4090+
bool isSendable =
4091+
(paramDecl->getAttrs().hasAttribute<UnsafeSendableAttr>() ||
4092+
func->hasKnownUnsafeSendableFunctionParams()) &&
4093+
inConcurrencyContext;
4094+
bool isMainActor =
4095+
(paramDecl->getAttrs().hasAttribute<UnsafeMainActorAttr>() ||
4096+
(isMainDispatchQueue &&
4097+
func->hasKnownUnsafeSendableFunctionParams())) &&
4098+
(inConcurrencyContext || numApplies >= 1);
4099+
4100+
if (!isSendable && !isMainActor) {
4101+
// If any prior parameter has changed, record this one.
4102+
if (!newTypeParams.empty())
4103+
newTypeParams.push_back(param);
4104+
continue;
4105+
}
4106+
4107+
// If this is the first parameter to have changed, copy all of the others
4108+
// over.
4109+
if (newTypeParams.empty()) {
4110+
newTypeParams.append(typeParams.begin(), typeParams.begin() + index);
4111+
}
4112+
4113+
4114+
// Transform the parameter type.
4115+
Type newParamType = applyUnsafeConcurrencyToParameterType(
4116+
param.getPlainType(), isSendable, isMainActor);
4117+
newTypeParams.push_back(param.withType(newParamType));
4118+
}
4119+
4120+
// If we didn't change any parameters, we're done.
4121+
if (newTypeParams.empty()) {
4122+
return outerFnType ? outerFnType : fnType;
4123+
}
4124+
4125+
// Rebuild the (inner) function type.
4126+
fnType = FunctionType::get(
4127+
newTypeParams, fnType->getResult(), fnType->getExtInfo());
4128+
4129+
if (!outerFnType)
4130+
return fnType;
4131+
4132+
// Rebuild the outer function type.
4133+
if (auto genericFnType = dyn_cast<GenericFunctionType>(outerFnType)) {
4134+
return GenericFunctionType::get(
4135+
genericFnType->getGenericSignature(), outerFnType->getParams(),
4136+
Type(fnType), outerFnType->getExtInfo());
4137+
}
4138+
4139+
return FunctionType::get(
4140+
outerFnType->getParams(), Type(fnType), outerFnType->getExtInfo());
4141+
}
4142+
40384143
AnyFunctionType *swift::applyGlobalActorType(
4039-
AnyFunctionType *fnType, ValueDecl *funcOrEnum, DeclContext *dc) {
4144+
AnyFunctionType *fnType, ValueDecl *funcOrEnum, DeclContext *dc,
4145+
unsigned numApplies, bool isMainDispatchQueue) {
4146+
// Apply unsafe concurrency features to the given function type.
4147+
fnType = applyUnsafeConcurrencyToFunctionType(
4148+
fnType, funcOrEnum, contextUsesConcurrencyFeatures(dc), numApplies,
4149+
isMainDispatchQueue);
4150+
40404151
Type globalActorType;
40414152
switch (auto isolation = getActorIsolation(funcOrEnum)) {
40424153
case ActorIsolation::ActorInstance:

lib/Sema/TypeChecker.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,8 @@ void bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *stmt);
13241324
/// references to its function type from the given declaration context,
13251325
/// update the given function type to include the global actor.
13261326
AnyFunctionType *applyGlobalActorType(
1327-
AnyFunctionType *fnType, ValueDecl *funcOrEnum, DeclContext *dc);
1327+
AnyFunctionType *fnType, ValueDecl *funcOrEnum, DeclContext *dc,
1328+
unsigned numApplies, bool isMainDispatchQueue);
13281329

13291330
/// If \p attr was added by an access note, wraps the error in
13301331
/// \c diag::wrap_invalid_attr_added_by_access_note and limits it as an access

test/Concurrency/dispatch_inference.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@ func testMe() {
1414
}
1515
}
1616

17-
func testUnsafeSendableInAsync() async {
17+
func testUnsafeSendableInMainAsync() async {
1818
var x = 5
1919
DispatchQueue.main.async {
2020
x = 17 // expected-error{{mutation of captured var 'x' in concurrently-executing code}}
2121
}
2222
print(x)
2323
}
24+
25+
func testUnsafeSendableInAsync(queue: DispatchQueue) async {
26+
var x = 5
27+
queue.async {
28+
x = 17 // expected-error{{mutation of captured var 'x' in concurrently-executing code}}
29+
}
30+
print(x)
31+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %target-typecheck-verify-swift -disable-availability-checking
2+
3+
func unsafelySendableClosure(@_unsafeSendable _ closure: () -> Void) { }
4+
5+
func unsafelyMainActorClosure(@_unsafeMainActor _ closure: () -> Void) { }
6+
7+
func unsafelyDoEverythingClosure(@_unsafeSendable @_unsafeMainActor _ closure: () -> Void) { }
8+
9+
struct X {
10+
func unsafelyDoEverythingClosure(@_unsafeSendable @_unsafeMainActor _ closure: () -> Void) { }
11+
}
12+
13+
14+
func testInAsync(x: X) async {
15+
let _: Int = unsafelySendableClosure // expected-error{{type '(@Sendable () -> Void) -> ()'}}
16+
let _: Int = unsafelyMainActorClosure // expected-error{{type '(@MainActor () -> Void) -> ()'}}
17+
let _: Int = unsafelyDoEverythingClosure // expected-error{{type '(@MainActor @Sendable () -> Void) -> ()'}}
18+
let _: Int = x.unsafelyDoEverythingClosure // expected-error{{type '(@MainActor @Sendable () -> Void) -> ()'}}
19+
let _: Int = X.unsafelyDoEverythingClosure // expected-error{{type '(X) -> (@MainActor @Sendable () -> Void) -> ()'}}
20+
let _: Int = (X.unsafelyDoEverythingClosure)(x) // expected-error{{type '(@MainActor @Sendable () -> Void) -> ()'}}
21+
}
22+
23+
func testElsewhere(x: X) {
24+
let _: Int = unsafelySendableClosure // expected-error{{type '(() -> Void) -> ()'}}
25+
let _: Int = unsafelyMainActorClosure // expected-error{{type '(() -> Void) -> ()'}}
26+
let _: Int = unsafelyDoEverythingClosure // expected-error{{type '(() -> Void) -> ()'}}
27+
let _: Int = x.unsafelyDoEverythingClosure // expected-error{{type '(() -> Void) -> ()'}}
28+
let _: Int = X.unsafelyDoEverythingClosure // expected-error{{type '(X) -> (() -> Void) -> ()'}}
29+
let _: Int = (X.unsafelyDoEverythingClosure)(x) // expected-error{{type '(() -> Void) -> ()'}}
30+
}
31+
32+
@MainActor func onMainActor() { }
33+
34+
func testCalls(x: X) {
35+
unsafelyMainActorClosure {
36+
onMainActor()
37+
}
38+
39+
unsafelyDoEverythingClosure {
40+
onMainActor()
41+
}
42+
43+
x.unsafelyDoEverythingClosure {
44+
onMainActor()
45+
}
46+
(X.unsafelyDoEverythingClosure)(x)( {
47+
onMainActor()
48+
})
49+
}

test/SILGen/check_executor.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ public actor MyActor {
3232
}
3333
}
3434

35-
// CHECK-RAW-LABEL: sil private [ossa] @$s4test7MyActorC0A10UnsafeMainyyFSiycfU_
36-
// CHECK-RAW-NOT: _checkExpectedExecutor
35+
// CHECK-RAW-LABEL: sil private [ossa] @$s4test7MyActorC0A10UnsafeMainyyFSiyScMYccfU_
36+
// CHECK-RAW: _checkExpectedExecutor
3737
// CHECK-RAW: onMainActor
3838
// CHECK-RAW: return
3939
public func testUnsafeMain() {

0 commit comments

Comments
 (0)