Skip to content

[6.0][TypeChecker] InferSendableFromCaptures: Infer @Sendable on adjusted types #72741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -4370,8 +4370,8 @@ class ConstraintSystem {

/// Wrapper over swift::adjustFunctionTypeForConcurrency that passes along
/// the appropriate closure-type and opening extraction functions.
AnyFunctionType *adjustFunctionTypeForConcurrency(
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
FunctionType *adjustFunctionTypeForConcurrency(
FunctionType *fnType, Type baseType, ValueDecl *decl, DeclContext *dc,
unsigned numApplies, bool isMainDispatchQueue,
OpenedTypeMap &replacements, ConstraintLocatorBuilder locator);

Expand Down
85 changes: 46 additions & 39 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1695,20 +1695,46 @@ static bool isRequirementOrWitness(const ConstraintLocatorBuilder &locator) {
locator.endsWith<LocatorPathElt::Witness>();
}

AnyFunctionType *ConstraintSystem::adjustFunctionTypeForConcurrency(
AnyFunctionType *fnType, ValueDecl *decl, DeclContext *dc,
FunctionType *ConstraintSystem::adjustFunctionTypeForConcurrency(
FunctionType *fnType, Type baseType, ValueDecl *decl, DeclContext *dc,
unsigned numApplies, bool isMainDispatchQueue, OpenedTypeMap &replacements,
ConstraintLocatorBuilder locator) {

return swift::adjustFunctionTypeForConcurrency(
fnType, decl, dc, numApplies, isMainDispatchQueue,
GetClosureType{*this}, ClosureIsolatedByPreconcurrency{*this},
[&](Type type) {
auto *adjustedTy = swift::adjustFunctionTypeForConcurrency(
fnType, decl, dc, numApplies, isMainDispatchQueue, GetClosureType{*this},
ClosureIsolatedByPreconcurrency{*this}, [&](Type type) {
if (replacements.empty())
return type;

return openType(type, replacements, locator);
});

if (Context.LangOpts.hasFeature(Feature::InferSendableFromCaptures)) {
if (auto *FD = dyn_cast<AbstractFunctionDecl>(decl)) {
auto *DC = FD->getDeclContext();
// All global functions should be @Sendable
if (DC->isModuleScopeContext()) {
if (!adjustedTy->getExtInfo().isSendable()) {
adjustedTy =
adjustedTy->withExtInfo(adjustedTy->getExtInfo().withSendable());
}
} else if (isPartialApplication(getConstraintLocator(locator))) {
if (baseType &&
(baseType->is<AnyMetatypeType>() || baseType->isSendableType())) {
auto referenceTy = adjustedTy->getResult()->castTo<FunctionType>();
referenceTy =
referenceTy->withExtInfo(referenceTy->getExtInfo().withSendable())
->getAs<FunctionType>();

adjustedTy =
FunctionType::get(adjustedTy->getParams(), referenceTy,
adjustedTy->getExtInfo().withSendable());
}
}
}
}

return adjustedTy->castTo<FunctionType>();
}

/// For every parameter in \p type that has an error type, replace that
Expand Down Expand Up @@ -1781,9 +1807,9 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
auto origOpenedType = openedType;
if (!isRequirementOrWitness(locator)) {
unsigned numApplies = getNumApplications(value, false, functionRefKind);
openedType = cast<FunctionType>(adjustFunctionTypeForConcurrency(
origOpenedType, func, useDC, numApplies, false, replacements,
locator));
openedType = adjustFunctionTypeForConcurrency(
origOpenedType, /*baseType=*/Type(), func, useDC, numApplies, false,
replacements, locator);
}

// The reference implicitly binds 'self'.
Expand All @@ -1799,14 +1825,6 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
auto numLabelsToRemove = getNumRemovedArgumentLabels(
funcDecl, /*isCurriedInstanceReference=*/false, functionRefKind);

if (Context.LangOpts.hasFeature(Feature::InferSendableFromCaptures)) {
// All global functions should be @Sendable
if (funcDecl->getDeclContext()->isModuleScopeContext()) {
funcType =
funcType->withExtInfo(funcType->getExtInfo().withSendable());
}
}

auto openedType = openFunctionType(funcType, locator, replacements,
funcDecl->getDeclContext())
->removeArgumentLabels(numLabelsToRemove);
Expand All @@ -1818,9 +1836,9 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
if (!isRequirementOrWitness(locator)) {
unsigned numApplies = getNumApplications(
funcDecl, false, functionRefKind);
openedType = cast<FunctionType>(adjustFunctionTypeForConcurrency(
origOpenedType->castTo<FunctionType>(), funcDecl, useDC, numApplies,
false, replacements, locator));
openedType = adjustFunctionTypeForConcurrency(
origOpenedType->castTo<FunctionType>(), /*baseType=*/Type(), funcDecl,
useDC, numApplies, false, replacements, locator);
}

if (isForCodeCompletion() && openedType->hasError()) {
Expand Down Expand Up @@ -2788,20 +2806,6 @@ ConstraintSystem::getTypeOfMemberReference(
// FIXME: Verify ExtInfo state is correct, not working by accident.
FunctionType::ExtInfo info;

if (Context.LangOpts.hasFeature(Feature::InferSendableFromCaptures)) {
if (isPartialApplication(locator) &&
(resolvedBaseTy->is<AnyMetatypeType>() ||
baseOpenedTy->isSendableType())) {
// Add @Sendable to functions without conditional conformances
functionType =
functionType
->withExtInfo(functionType->getExtInfo().withSendable())
->getAs<FunctionType>();
}
// Unapplied values should always be Sendable
info = info.withSendable();
}

// We'll do other adjustment later, but we need to handle parameter
// isolation to avoid assertions.
if (fullFunctionType->getIsolation().isParameter())
Expand All @@ -2819,11 +2823,12 @@ ConstraintSystem::getTypeOfMemberReference(
unsigned numApplies = getNumApplications(
value, hasAppliedSelf, functionRefKind);
openedType = adjustFunctionTypeForConcurrency(
origOpenedType->castTo<AnyFunctionType>(), value, useDC, numApplies,
isMainDispatchQueueMember(locator), replacements, locator);
origOpenedType->castTo<FunctionType>(), resolvedBaseTy, value, useDC,
numApplies, isMainDispatchQueueMember(locator), replacements, locator);
} else if (auto subscript = dyn_cast<SubscriptDecl>(value)) {
openedType = adjustFunctionTypeForConcurrency(
origOpenedType->castTo<AnyFunctionType>(), subscript, useDC,
origOpenedType->castTo<FunctionType>(), resolvedBaseTy, subscript,
useDC,
/*numApplies=*/2, /*isMainDispatchQueue=*/false, replacements, locator);
} else if (auto var = dyn_cast<VarDecl>(value)) {
// Adjust the function's result type, since that's the Var's actual type.
Expand Down Expand Up @@ -2952,7 +2957,8 @@ Type ConstraintSystem::getEffectiveOverloadType(ConstraintLocator *locator,
// FIXME: Verify ExtInfo state is correct, not working by accident.
FunctionType::ExtInfo info;
type = adjustFunctionTypeForConcurrency(
FunctionType::get(indices, elementTy, info), subscript, useDC,
FunctionType::get(indices, elementTy, info), overload.getBaseType(),
subscript, useDC,
/*numApplies=*/1, /*isMainDispatchQueue=*/false, emptyReplacements,
locator);
} else if (auto var = dyn_cast<VarDecl>(decl)) {
Expand Down Expand Up @@ -3000,7 +3006,8 @@ Type ConstraintSystem::getEffectiveOverloadType(ConstraintLocator *locator,
decl, hasAppliedSelf, overload.getFunctionRefKind());

type = adjustFunctionTypeForConcurrency(
type->castTo<FunctionType>(), decl, useDC, numApplies,
type->castTo<FunctionType>(), overload.getBaseType(), decl,
useDC, numApplies,
/*isMainDispatchQueue=*/false, emptyReplacements, locator)
->getResult();
}
Expand Down
22 changes: 14 additions & 8 deletions lib/Sema/MiscDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC,
// Void to _ then warn, because that is redundant.
if (auto DAE = dyn_cast<DiscardAssignmentExpr>(destExpr)) {
if (auto CE = dyn_cast<CallExpr>(AE->getSrc())) {
if (isa_and_nonnull<FuncDecl>(CE->getCalledValue()) &&
if (getAsDecl<FuncDecl>(
CE->getCalledValue(/*skipFunctionConversions=*/true)) &&
CE->getType()->isVoid()) {
Ctx.Diags
.diagnose(DAE->getLoc(),
Expand Down Expand Up @@ -311,7 +312,7 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC,
// FIXME: Duplicate labels on enum payloads should be diagnosed
// when declared, not when called.
if (auto *CE = dyn_cast_or_null<CallExpr>(E)) {
auto calledValue = CE->getCalledValue();
auto calledValue = CE->getCalledValue(/*skipFunctionConversions=*/true);
if (calledValue && isa<EnumElementDecl>(calledValue)) {
auto *args = CE->getArgs();
SmallVector<Identifier, 4> scratch;
Expand Down Expand Up @@ -1464,6 +1465,9 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC,
if (auto dotSyntax = dyn_cast<DotSyntaxCallExpr>(fnExpr))
fnExpr = dotSyntax->getSemanticFn();

if (auto *FCE = dyn_cast<FunctionConversionExpr>(fnExpr))
fnExpr = FCE->getSubExpr();

auto DRE = dyn_cast<DeclRefExpr>(fnExpr);
if (!DRE || !DRE->getDecl()->isOperator())
return;
Expand Down Expand Up @@ -5093,7 +5097,7 @@ static void diagnoseUnintendedOptionalBehavior(const Expr *E,
return declRef->getDecl();

if (auto *apply = dyn_cast<ApplyExpr>(E)) {
auto *decl = apply->getCalledValue();
auto *decl = apply->getCalledValue(/*skipFunctionConversions=*/true);
if (isa_and_nonnull<AbstractFunctionDecl>(decl))
return decl;
}
Expand Down Expand Up @@ -5232,9 +5236,10 @@ static void diagnoseUnintendedOptionalBehavior(const Expr *E,

void diagnoseIfUnintendedInterpolation(CallExpr *segment,
UnintendedInterpolationKind kind) {
if (interpolationWouldBeUnintended(segment->getCalledValue(), kind))
if (interpolationWouldBeUnintended(
segment->getCalledValue(/*skipFunctionConversions=*/true), kind))
if (auto firstArg =
getFirstArgIfUnintendedInterpolation(segment->getArgs(), kind))
getFirstArgIfUnintendedInterpolation(segment->getArgs(), kind))
diagnoseUnintendedInterpolation(firstArg, kind);
}

Expand Down Expand Up @@ -5445,7 +5450,7 @@ static void maybeDiagnoseCallToKeyValueObserveMethod(const Expr *E,
KVOObserveCallWalker(ASTContext &ctx) : C(ctx) {}

void maybeDiagnoseCallExpr(CallExpr *expr) {
auto fn = expr->getCalledValue();
auto fn = expr->getCalledValue(/*skipFunctionConversions=*/true);
if (!fn)
return;
SmallVector<KeyPathExpr *, 1> keyPathArgs;
Expand Down Expand Up @@ -5583,9 +5588,10 @@ static void diagnoseComparisonWithNaN(const Expr *E, const DeclContext *DC) {
// Dig out the function declaration.
if (auto Fn = BE->getFn()) {
if (auto DSCE = dyn_cast<DotSyntaxCallExpr>(Fn)) {
comparisonDecl = DSCE->getCalledValue();
comparisonDecl =
DSCE->getCalledValue(/*skipFunctionConversions=*/true);
} else {
comparisonDecl = BE->getCalledValue();
comparisonDecl = BE->getCalledValue(/*skipFunctionConversions=*/true);
}
}

Expand Down
16 changes: 0 additions & 16 deletions test/Concurrency/sendable_functions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,3 @@ extension S: Sendable where T: Sendable {

@available(SwiftStdlib 5.1, *)
@MainActor @Sendable func globalActorFuncAsync() async { }

func test_initializer_ref() {
func test<T>(_: @Sendable (T, T) -> Array<T>) {
}

// Type of `initRef` should be @Sendable but due to implicitly injected autoclosure it isn't
let initRef = Array.init as (Int, Int) -> Array<Int>

// FIXME: incorrect non-Sendable diagnostic is produced due to `autoclosure` wrapping `Array.init`
test(initRef)
// expected-warning@-1 {{converting non-sendable function value to '@Sendable (Int, Int) -> Array<Int>' may introduce data races}}

// FIXME: Same here
test(Array.init as (Int, Int) -> Array<Int>)
// expected-warning@-1 {{converting non-sendable function value to '@Sendable (Int, Int) -> Array<Int>' may introduce data races}}
}
33 changes: 33 additions & 0 deletions test/Concurrency/sendable_methods.swift
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,36 @@ do {

let _: () -> Void = forward(Test.fn) // Ok
}


func test_initializer_ref() {
func test<T>(_: @Sendable (T, T) -> Array<T>) {
}

let initRef: @Sendable (Int, Int) -> Array<Int> = Array<Int>.init // Ok

test(initRef) // Ok
test(Array<Int>.init) // Ok
}

// rdar://119593407 - incorrect errors when partially applied member is accessed with InferSendableFromCaptures
do {
@MainActor struct ErrorHandler {
static func log(_ error: Error) {}
}

@MainActor final class Manager {
static var shared: Manager!

func test(_: @escaping @MainActor (Error) -> Void) {
}
}

@MainActor class Test {
func schedule() {
Task {
Manager.shared.test(ErrorHandler.log) // Ok (access is wrapped in an autoclosure)
}
}
}
}