Skip to content

[6.0] Fix function subtyping rules for sending #74428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 17, 2024
Merged
13 changes: 13 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -7952,6 +7952,19 @@ ERROR(sending_only_on_parameters_and_results, none,
"'sending' may only be used on parameters and results", ())
ERROR(sending_cannot_be_applied_to_tuple_elt, none,
"'sending' cannot be applied to tuple elements", ())
ERROR(sending_function_wrong_sending,none,
"converting a value of type %0 to type %1 risks causing data races",
(Type, Type))
NOTE(sending_function_param_with_sending_param_note, none,
"converting a function typed value with a sending parameter to one "
"without risks allowing actor-isolated values to escape their isolation "
"domain as an argument to an invocation of value",
())
NOTE(sending_function_result_with_sending_param_note, none,
"converting a function typed value without a sending result as one with "
"risks allowing actor-isolated values to escape their "
"isolation domain through a result of an invocation of value",
())

#define UNDEFINE_DIAGNOSTIC_MACROS
#include "DefineDiagnosticMacros.h"
3 changes: 3 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3237,6 +3237,9 @@ class AnyFunctionType : public TypeBase {
/// Whether the parameter is 'isolated'.
bool isIsolated() const { return Flags.isIsolated(); }

/// Whether or not the parameter is 'sending'.
bool isSending() const { return Flags.isSending(); }

/// Whether the parameter is 'isCompileTimeConst'.
bool isCompileTimeConst() const { return Flags.isCompileTimeConst(); }

Expand Down
54 changes: 54 additions & 0 deletions include/swift/Sema/CSFix.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,17 @@ enum class FixKind : uint8_t {
/// Ignore situations when key path subscript index gets passed an invalid
/// type as an argument (something that is not a key path).
IgnoreKeyPathSubscriptIndexMismatch,

/// Ignore the following situations:
///
/// 1. Where we have a function that expects a function typed parameter
/// without a sendable parameter but is passed a function type with a sending
/// parameter.
///
/// 2. Where we have a function that expects a function typed parameter with a
/// sending result, but is passed a function typeed parameter without a
/// sending result.
AllowSendingMismatch,
};

class ConstraintFix {
Expand Down Expand Up @@ -2619,6 +2630,49 @@ class TreatEphemeralAsNonEphemeral final : public AllowArgumentMismatch {
}
};

/// Error if a user passes let f: (sending T) -> () as a (T) -> ().
///
/// This prevents data races since f assumes its parameter if the parameter is
/// non-Sendable is safe to transfer onto other situations. The caller though
/// that this is being sent to does not enforce that invariants within its body.
class AllowSendingMismatch final : public ContextualMismatch {
public:
enum class Kind {
Parameter,
Result,
};

private:
Kind kind;

AllowSendingMismatch(ConstraintSystem &cs, Type argType, Type paramType,
ConstraintLocator *locator, Kind kind,
FixBehavior fixBehavior)
: ContextualMismatch(cs, FixKind::AllowSendingMismatch, argType,
paramType, locator, fixBehavior),
kind(kind) {}

public:
std::string getName() const override {
return "treat a function argument with sending parameter as a function "
"argument without sending parameters";
}

bool diagnose(const Solution &solution, bool asNote = false) const override;

bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override {
return diagnose(*commonFixes.front().first);
}

static AllowSendingMismatch *create(ConstraintSystem &cs,
ConstraintLocator *locator, Type srcType,
Type dstType, Kind kind);

static bool classof(const ConstraintFix *fix) {
return fix->getKind() == FixKind::AllowSendingMismatch;
}
};

class SpecifyBaseTypeForContextualMember final : public ConstraintFix {
DeclNameRef MemberName;

Expand Down
2 changes: 2 additions & 0 deletions lib/Parse/ParseType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,8 @@ bool Parser::canParseType() {
consumeToken();
} else if (Tok.isContextualKeyword("each")) {
consumeToken();
} else if (Tok.isContextualKeyword("sending")) {
consumeToken();
}

switch (Tok.getKind()) {
Expand Down
4 changes: 3 additions & 1 deletion lib/Sema/AssociatedTypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2425,7 +2425,9 @@ AssociatedTypeInference::computeFailureTypeWitness(
// it.
for (const auto &witness : valueWitnesses) {
if (isAsyncIteratorProtocolNext(witness.first)) {
if (auto witnessFunc = dyn_cast<AbstractFunctionDecl>(witness.second)) {
// We use a dyn_cast_or_null since we can get a nullptr here if we fail to
// match a witness. In such a case, we should just fail here.
if (auto witnessFunc = dyn_cast_or_null<AbstractFunctionDecl>(witness.second)) {
auto thrownError = witnessFunc->getEffectiveThrownErrorType();

// If it doesn't throw, Failure == Never.
Expand Down
18 changes: 18 additions & 0 deletions lib/Sema/CSDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7918,6 +7918,24 @@ bool NonEphemeralConversionFailure::diagnoseAsError() {
return true;
}

bool SendingOnFunctionParameterMismatchFail::diagnoseAsError() {
emitDiagnosticAt(getLoc(), diag::sending_function_wrong_sending,
getFromType(), getToType())
.warnUntilSwiftVersion(6);
emitDiagnosticAt(getLoc(),
diag::sending_function_param_with_sending_param_note);
return true;
}

bool SendingOnFunctionResultMismatchFailure::diagnoseAsError() {
emitDiagnosticAt(getLoc(), diag::sending_function_wrong_sending,
getFromType(), getToType())
.warnUntilSwiftVersion(6);
emitDiagnosticAt(getLoc(),
diag::sending_function_result_with_sending_param_note);
return true;
}

bool AssignmentTypeMismatchFailure::diagnoseMissingConformance() const {
auto srcType = getFromType();
auto dstType = getToType()->lookThroughAllOptionalTypes();
Expand Down
22 changes: 22 additions & 0 deletions lib/Sema/CSDiagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2291,6 +2291,28 @@ class NonEphemeralConversionFailure final : public ArgumentMismatchFailure {
void emitSuggestionNotes() const;
};

class SendingOnFunctionParameterMismatchFail final : public ContextualFailure {
public:
SendingOnFunctionParameterMismatchFail(const Solution &solution, Type srcType,
Type dstType,
ConstraintLocator *locator,
FixBehavior fixBehavior)
: ContextualFailure(solution, srcType, dstType, locator, fixBehavior) {}

bool diagnoseAsError() override;
};

class SendingOnFunctionResultMismatchFailure final : public ContextualFailure {
public:
SendingOnFunctionResultMismatchFailure(const Solution &solution, Type srcType,
Type dstType,
ConstraintLocator *locator,
FixBehavior fixBehavior)
: ContextualFailure(solution, srcType, dstType, locator, fixBehavior) {}

bool diagnoseAsError() override;
};

class AssignmentTypeMismatchFailure final : public ContextualFailure {
public:
AssignmentTypeMismatchFailure(const Solution &solution,
Expand Down
28 changes: 28 additions & 0 deletions lib/Sema/CSFix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,34 @@ std::string TreatEphemeralAsNonEphemeral::getName() const {
return name;
}

bool AllowSendingMismatch::diagnose(const Solution &solution,
bool asNote) const {
switch (kind) {
case Kind::Parameter: {
SendingOnFunctionParameterMismatchFail failure(
solution, getFromType(), getToType(), getLocator(), fixBehavior);
return failure.diagnose(asNote);
}
case Kind::Result: {
SendingOnFunctionResultMismatchFailure failure(
solution, getFromType(), getToType(), getLocator(), fixBehavior);
return failure.diagnose(asNote);
}
}
llvm_unreachable("Covered switch isn't covered?!");
}

AllowSendingMismatch *AllowSendingMismatch::create(ConstraintSystem &cs,
ConstraintLocator *locator,
Type srcType, Type dstType,
Kind kind) {
auto fixBehavior = cs.getASTContext().LangOpts.isSwiftVersionAtLeast(6)
? FixBehavior::Error
: FixBehavior::DowngradeToWarning;
return new (cs.getAllocator())
AllowSendingMismatch(cs, srcType, dstType, locator, kind, fixBehavior);
}

bool SpecifyBaseTypeForContextualMember::diagnose(const Solution &solution,
bool asNote) const {
MissingContextualBaseInMemberRefFailure failure(solution, MemberName,
Expand Down
11 changes: 6 additions & 5 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1636,10 +1636,9 @@ namespace {
}

Type
resolveTypeReferenceInExpression(TypeRepr *repr, TypeResolverContext resCtx,
resolveTypeReferenceInExpression(TypeRepr *repr,
TypeResolutionOptions options,
const ConstraintLocatorBuilder &locator) {
TypeResolutionOptions options(resCtx);

// Introduce type variables for unbound generics.
const auto genericOpener = OpenUnboundGenericType(CS, locator);
const auto placeholderHandler = HandlePlaceholderType(CS, locator);
Expand Down Expand Up @@ -2528,9 +2527,11 @@ namespace {
return declaredTy;
}

auto options =
TypeResolutionOptions(TypeResolverContext::InExpression);
options.setContext(TypeResolverContext::ClosureExpr);
const auto resolvedTy = resolveTypeReferenceInExpression(
closure->getExplicitResultTypeRepr(),
TypeResolverContext::InExpression, resultLocator);
closure->getExplicitResultTypeRepr(), options, resultLocator);
if (resolvedTy)
return resolvedTy;
}
Expand Down
36 changes: 32 additions & 4 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3236,6 +3236,16 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
return getTypeMatchFailure(locator);
}

// () -> sending T can be a subtype of () -> T... but not vis-a-versa.
if (func1->hasSendingResult() != func2->hasSendingResult() &&
(!func1->hasSendingResult() || kind < ConstraintKind::Subtype)) {
auto *fix = AllowSendingMismatch::create(
*this, getConstraintLocator(locator), func1, func2,
AllowSendingMismatch::Kind::Result);
if (recordFix(fix))
return getTypeMatchFailure(locator);
}

if (!matchFunctionIsolations(func1, func2, kind, flags, locator))
return getTypeMatchFailure(locator);

Expand Down Expand Up @@ -3666,6 +3676,17 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
return getTypeMatchFailure(argumentLocator);
}

// Do not allow for functions that expect a sending parameter to match
// with a function that expects a non-sending parameter.
if (func1Param.getParameterFlags().isSending() &&
!func2Param.getParameterFlags().isSending()) {
auto *fix = AllowSendingMismatch::create(
*this, getConstraintLocator(argumentLocator), func1, func2,
AllowSendingMismatch::Kind::Parameter);
if (recordFix(fix))
return getTypeMatchFailure(argumentLocator);
}

// FIXME: We should check value ownership too, but it's not completely
// trivial because of inout-to-pointer conversions.

Expand Down Expand Up @@ -11770,10 +11791,10 @@ bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar,
if (contextualParam->isIsolated() && !flags.isIsolated() && paramDecl)
isolatedParams.insert(paramDecl);

param =
param.withFlags(flags.withInOut(contextualParam->isInOut())
.withVariadic(contextualParam->isVariadic())
.withIsolated(contextualParam->isIsolated()));
param = param.withFlags(flags.withInOut(contextualParam->isInOut())
.withVariadic(contextualParam->isVariadic())
.withIsolated(contextualParam->isIsolated())
.withSending(contextualParam->isSending()));
}
}

Expand Down Expand Up @@ -11900,6 +11921,12 @@ bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar,
closureExtInfo = closureExtInfo.withSendable();
}

// Propagate sending result from the contextual type to the closure.
if (auto contextualFnType = contextualType->getAs<FunctionType>()) {
if (contextualFnType->hasExtInfo() && contextualFnType->hasSendingResult())
closureExtInfo = closureExtInfo.withSendingResult();
}

// Isolated parameters override any other kind of isolation we might infer.
if (hasIsolatedParam) {
closureExtInfo = closureExtInfo.withIsolation(
Expand Down Expand Up @@ -15098,6 +15125,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint(
}
}

case FixKind::AllowSendingMismatch:
case FixKind::InsertCall:
case FixKind::RemoveReturn:
case FixKind::RemoveAddressOf:
Expand Down
17 changes: 17 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,15 @@ RequirementMatch swift::matchWitness(
reqTypeIsIUO != witnessTypeIsIUO)
return RequirementMatch(witness, MatchKind::TypeConflict, witnessType);

// If our requirement says that it has a sending result, then our witness
// must also have a sending result since otherwise, in generic contexts,
// we would be returning non-disconnected values as disconnected.
if (dc->getASTContext().LangOpts.isSwiftVersionAtLeast(6)) {
if (reqFnType->hasExtInfo() && reqFnType->hasSendingResult() &&
(!witnessFnType->hasExtInfo() || !witnessFnType->hasSendingResult()))
return RequirementMatch(witness, MatchKind::TypeConflict, witnessType);
}

if (auto result = matchTypes(std::get<0>(types), std::get<1>(types))) {
return std::move(result.value());
}
Expand Down Expand Up @@ -775,6 +784,14 @@ RequirementMatch swift::matchWitness(
if (reqParams[i].isInOut() != witnessParams[i].isInOut())
return RequirementMatch(witness, MatchKind::TypeConflict, witnessType);

// If we have a requirement without sending and our witness expects a
// sending parameter, error.
if (dc->getASTContext().isSwiftVersionAtLeast(6)) {
if (!reqParams[i].getParameterFlags().isSending() &&
witnessParams[i].getParameterFlags().isSending())
return RequirementMatch(witness, MatchKind::TypeConflict, witnessType);
}

auto reqParamDecl = reqParamList->get(i);
auto witnessParamDecl = witnessParamList->get(i);

Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/TypeCheckType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4971,7 +4971,8 @@ TypeResolver::resolveSendingTypeRepr(SendingTypeRepr *repr,
return ErrorType::get(getASTContext());
}

if (!options.is(TypeResolverContext::FunctionResult) &&
if (!options.is(TypeResolverContext::ClosureExpr) &&
!options.is(TypeResolverContext::FunctionResult) &&
(!options.is(TypeResolverContext::FunctionInput) ||
options.hasBase(TypeResolverContext::EnumElementDecl))) {
diagnoseInvalid(repr, repr->getSpecifierLoc(),
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckType.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ enum class TypeResolverContext : uint8_t {
/// Whether we are checking the parameter list of a subscript.
SubscriptDecl,

/// Whether we are checking the parameter list of a closure.
/// Whether we are checking the parameter list or result of a closure.
ClosureExpr,

/// Whether we are in the input type of a function, or under one level of
Expand Down
Loading