Skip to content

Commit 636525e

Browse files
authored
Merge pull request #78171 from xedin/rdar-140300022
[TypeChecker/SILGen] Allow `any Sendable` to match `Any` while matching generic arguments
2 parents 51cce0d + a2f711c commit 636525e

20 files changed

+725
-20
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,6 +2798,9 @@ ERROR(types_not_inherited_decl,none,
27982798
ERROR(types_not_inherited_in_decl_ref,none,
27992799
"referencing %kind0 on %1 requires that %2 inherit from %3",
28002800
(const ValueDecl *, Type, Type, Type))
2801+
ERROR(cannot_reference_conditional_member_on_base_multiple_mismatches,none,
2802+
"cannot reference %kind0 on %1",
2803+
(const ValueDecl *, Type))
28012804
NOTE(where_requirement_failure_one_subst,none,
28022805
"where %0 = %1", (Type, Type))
28032806
NOTE(where_requirement_failure_both_subst,none,

include/swift/AST/Expr.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3600,6 +3600,24 @@ class ActorIsolationErasureExpr : public ImplicitConversionExpr {
36003600
}
36013601
};
36023602

3603+
/// UnsafeCastExpr - A special kind of conversion that performs an unsafe
3604+
/// bitcast from one type to the other.
3605+
///
3606+
/// Note that this is an unsafe operation and type-checker is allowed to
3607+
/// use this only in a limited number of cases like: `any Sendable` -> `Any`
3608+
/// conversions in some positions, covariant conversions of function and
3609+
/// function result types.
3610+
class UnsafeCastExpr : public ImplicitConversionExpr {
3611+
public:
3612+
UnsafeCastExpr(Expr *subExpr, Type type)
3613+
: ImplicitConversionExpr(ExprKind::UnsafeCast, subExpr, type) {
3614+
}
3615+
3616+
static bool classof(const Expr *E) {
3617+
return E->getKind() == ExprKind::UnsafeCast;
3618+
}
3619+
};
3620+
36033621
/// Extracts the isolation of a dynamically isolated function value.
36043622
class ExtractFunctionIsolationExpr : public Expr {
36053623
/// The function value expression from which to extract the

include/swift/AST/ExprNodes.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ ABSTRACT_EXPR(ImplicitConversion, Expr)
191191
EXPR(LinearFunctionExtractOriginal, ImplicitConversionExpr)
192192
EXPR(LinearToDifferentiableFunction, ImplicitConversionExpr)
193193
EXPR(ActorIsolationErasure, ImplicitConversionExpr)
194-
EXPR_RANGE(ImplicitConversion, Load, ActorIsolationErasure)
194+
EXPR(UnsafeCast, ImplicitConversionExpr)
195+
EXPR_RANGE(ImplicitConversion, Load, UnsafeCast)
195196
ABSTRACT_EXPR(ExplicitCast, Expr)
196197
ABSTRACT_EXPR(CheckedCast, ExplicitCastExpr)
197198
EXPR(ForcedCheckedCast, CheckedCastExpr)

include/swift/AST/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,9 @@ class alignas(1 << TypeAlignInBits) TypeBase
675675
/// Is this an existential containing only marker protocols?
676676
bool isMarkerExistential();
677677

678+
/// Is this `any Sendable` type?
679+
bool isSendableExistential();
680+
678681
bool isPlaceholder();
679682

680683
/// Returns true if this contextual type does not satisfy a conformance to

lib/AST/ASTDumper.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,6 +2834,12 @@ class PrintExpr : public ExprVisitor<PrintExpr, void, StringRef>,
28342834
printFoot();
28352835
}
28362836

2837+
void visitUnsafeCastExpr(UnsafeCastExpr *E, StringRef label) {
2838+
printCommon(E, "unsafe_cast_expr", label);
2839+
printRec(E->getSubExpr());
2840+
printFoot();
2841+
}
2842+
28372843
void visitExtractFunctionIsolationExpr(ExtractFunctionIsolationExpr *E,
28382844
StringRef label) {
28392845
printCommon(E, "extract_function_isolation", label);

lib/AST/ASTPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5343,6 +5343,9 @@ void PrintAST::visitLinearToDifferentiableFunctionExpr(swift::LinearToDifferenti
53435343
void PrintAST::visitActorIsolationErasureExpr(ActorIsolationErasureExpr *expr) {
53445344
}
53455345

5346+
void PrintAST::visitUnsafeCastExpr(UnsafeCastExpr *expr) {
5347+
}
5348+
53465349
void PrintAST::visitExtractFunctionIsolationExpr(ExtractFunctionIsolationExpr *expr) {
53475350
visit(expr->getFunctionExpr());
53485351
Printer << ".isolation";

lib/AST/Expr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ ConcreteDeclRef Expr::getReferencedDecl(bool stopAtParenExpr) const {
444444
PASS_THROUGH_REFERENCE(UnderlyingToOpaque, getSubExpr);
445445
PASS_THROUGH_REFERENCE(Unreachable, getSubExpr);
446446
PASS_THROUGH_REFERENCE(ActorIsolationErasure, getSubExpr);
447+
PASS_THROUGH_REFERENCE(UnsafeCast, getSubExpr);
447448
NO_REFERENCE(Coerce);
448449
NO_REFERENCE(ForcedCheckedCast);
449450
NO_REFERENCE(ConditionalCheckedCast);
@@ -813,6 +814,7 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const {
813814
case ExprKind::UnderlyingToOpaque:
814815
case ExprKind::Unreachable:
815816
case ExprKind::ActorIsolationErasure:
817+
case ExprKind::UnsafeCast:
816818
case ExprKind::TypeValue:
817819
// Implicit conversion nodes have no syntax of their own; defer to the
818820
// subexpression.
@@ -1043,6 +1045,7 @@ bool Expr::isValidParentOfTypeExpr(Expr *typeExpr) const {
10431045
case ExprKind::CurrentContextIsolation:
10441046
case ExprKind::ActorIsolationErasure:
10451047
case ExprKind::ExtractFunctionIsolation:
1048+
case ExprKind::UnsafeCast:
10461049
return false;
10471050
}
10481051

lib/AST/Type.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,17 @@ bool TypeBase::isMarkerExistential() {
163163
return true;
164164
}
165165

166+
bool TypeBase::isSendableExistential() {
167+
Type constraint = this;
168+
if (auto existential = constraint->getAs<ExistentialType>())
169+
constraint = existential->getConstraintType();
170+
171+
if (!constraint->isConstraintType())
172+
return false;
173+
174+
return constraint->getKnownProtocol() == KnownProtocolKind::Sendable;
175+
}
176+
166177
bool TypeBase::isPlaceholder() {
167178
return is<PlaceholderType>();
168179
}

lib/SILGen/SILGenBuilder.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,31 @@ ManagedValue SILGenBuilder::createUncheckedBitCast(SILLocation loc,
757757
return cloner.clone(cast);
758758
}
759759

760+
ManagedValue SILGenBuilder::createUncheckedForwardingCast(SILLocation loc,
761+
ManagedValue value,
762+
SILType type) {
763+
CleanupCloner cloner(*this, value);
764+
SILValue cast = createUncheckedForwardingCast(loc, value.getValue(), type);
765+
766+
// Currently createUncheckedBitCast only produces these
767+
// instructions. We assert here to make sure if this changes, this code is
768+
// updated.
769+
assert((isa<UncheckedTrivialBitCastInst>(cast) ||
770+
isa<UncheckedRefCastInst>(cast) ||
771+
isa<UncheckedValueCastInst>(cast) ||
772+
isa<ConvertFunctionInst>(cast)) &&
773+
"SILGenBuilder is out of sync with SILBuilder.");
774+
775+
// If we have a trivial inst, just return early.
776+
if (isa<UncheckedTrivialBitCastInst>(cast))
777+
return ManagedValue::forObjectRValueWithoutOwnership(cast);
778+
779+
// Otherwise, we forward the cleanup of the input value and place the cleanup
780+
// on the cast value since unchecked_ref_cast is "forwarding".
781+
value.forward(SGF);
782+
return cloner.clone(cast);
783+
}
784+
760785
ManagedValue SILGenBuilder::createOpenExistentialRef(SILLocation loc,
761786
ManagedValue original,
762787
SILType type) {

lib/SILGen/SILGenBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ class SILGenBuilder : public SILBuilder {
340340
ManagedValue createUncheckedBitCast(SILLocation loc, ManagedValue original,
341341
SILType type);
342342

343+
using SILBuilder::createUncheckedForwardingCast;
344+
ManagedValue createUncheckedForwardingCast(SILLocation loc,
345+
ManagedValue original,
346+
SILType type);
347+
343348
using SILBuilder::createOpenExistentialRef;
344349
ManagedValue createOpenExistentialRef(SILLocation loc, ManagedValue arg,
345350
SILType openedType);

lib/SILGen/SILGenExpr.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ namespace {
497497
RValue visitCovariantReturnConversionExpr(
498498
CovariantReturnConversionExpr *E,
499499
SGFContext C);
500+
RValue visitUnsafeCastExpr(UnsafeCastExpr *E, SGFContext C);
500501
RValue visitErasureExpr(ErasureExpr *E, SGFContext C);
501502
RValue visitAnyHashableErasureExpr(AnyHashableErasureExpr *E, SGFContext C);
502503
RValue visitForcedCheckedCastExpr(ForcedCheckedCastExpr *E,
@@ -2132,6 +2133,24 @@ RValue RValueEmitter::visitExtractFunctionIsolationExpr(
21322133
return RValue(SGF, E, result);
21332134
}
21342135

2136+
RValue RValueEmitter::visitUnsafeCastExpr(UnsafeCastExpr *E, SGFContext C) {
2137+
ManagedValue original = SGF.emitRValueAsSingleValue(E->getSubExpr());
2138+
SILType resultType = SGF.getLoweredType(E->getType());
2139+
2140+
if (resultType == original.getType())
2141+
return RValue(SGF, E, original);
2142+
2143+
ManagedValue result;
2144+
if (original.getType().isAddress()) {
2145+
ASSERT(resultType.isAddress());
2146+
result = SGF.B.createUncheckedAddrCast(E, original, resultType);
2147+
} else {
2148+
result = SGF.B.createUncheckedForwardingCast(E, original, resultType);
2149+
}
2150+
2151+
return RValue(SGF, E, result);
2152+
}
2153+
21352154
RValue RValueEmitter::visitErasureExpr(ErasureExpr *E, SGFContext C) {
21362155
if (auto result = tryEmitAsBridgingConversion(SGF, E, false, C)) {
21372156
return RValue(SGF, E, *result);

lib/Sema/CSApply.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7034,6 +7034,18 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
70347034
}
70357035
}
70367036

7037+
// `any Sendable` -> `Any` conversion is allowed in generic
7038+
// argument positions.
7039+
{
7040+
auto erasedFromType = fromType->stripConcurrency(
7041+
/*recursive=*/true, /*dropGlobalActor=*/false);
7042+
auto erasedToType = toType->stripConcurrency(
7043+
/*recursive=*/true, /*dropGlobalActor=*/false);
7044+
7045+
if (erasedFromType->isEqual(erasedToType))
7046+
return cs.cacheType(new (ctx) UnsafeCastExpr(expr, toType));
7047+
}
7048+
70377049
auto &err = llvm::errs();
70387050
err << "fromType->getCanonicalType() = ";
70397051
fromType->getCanonicalType()->dump(err);

lib/Sema/CSDiagnostics.cpp

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -911,16 +911,44 @@ bool GenericArgumentsMismatchFailure::diagnoseAsError() {
911911
// before pointer types could be compared.
912912
auto locator = getLocator();
913913
auto path = locator->getPath();
914-
unsigned toDrop = 0;
915-
for (const auto &elt : llvm::reverse(path)) {
916-
if (!elt.is<LocatorPathElt::OptionalPayload>())
917-
break;
918914

919-
// Disregard optional payload element to look at its source.
920-
++toDrop;
915+
// If there are generic types involved, we need to find
916+
// the outermost generic types and report on them instead
917+
// of their arguments.
918+
// For example:
919+
//
920+
// <expr> -> contextual type
921+
// -> generic type S<[Int]>
922+
// -> generic type S<[String]>
923+
// -> generic argument #0
924+
//
925+
// Is going to have from/to types as `[Int]` and `[String]` but
926+
// the diagnostic should mention `S<[Int]>` and `S<[String]>`
927+
// because it refers to a contextual type location.
928+
if (locator->isLastElement<LocatorPathElt::GenericArgument>()) {
929+
for (unsigned i = 0; i < path.size(); ++i) {
930+
if (auto genericType = path[i].getAs<LocatorPathElt::GenericType>()) {
931+
ASSERT(i + 1 < path.size());
932+
933+
fromType = resolveType(genericType->getType());
934+
toType = resolveType(
935+
path[i + 1].castTo<LocatorPathElt::GenericType>().getType());
936+
break;
937+
}
938+
}
921939
}
922940

923-
path = path.drop_back(toDrop);
941+
while (!path.empty()) {
942+
auto last = path.back();
943+
if (last.is<LocatorPathElt::OptionalPayload>() ||
944+
last.is<LocatorPathElt::GenericType>() ||
945+
last.is<LocatorPathElt::GenericArgument>()) {
946+
path = path.drop_back();
947+
continue;
948+
}
949+
950+
break;
951+
}
924952

925953
std::optional<Diag<Type, Type>> diagnostic;
926954
if (path.empty()) {
@@ -1016,6 +1044,34 @@ bool GenericArgumentsMismatchFailure::diagnoseAsError() {
10161044
break;
10171045
}
10181046

1047+
case ConstraintLocator::Member: {
1048+
auto *memberLoc = getConstraintLocator(anchor, path);
1049+
auto selectedOverload = getOverloadChoiceIfAvailable(memberLoc);
1050+
if (!selectedOverload)
1051+
return false;
1052+
1053+
auto baseTy = selectedOverload->choice.getBaseType()->getRValueType();
1054+
auto *memberRef = selectedOverload->choice.getDecl();
1055+
1056+
if (Mismatches.size() == 1) {
1057+
auto mismatchIdx = Mismatches.front();
1058+
auto actualArgTy = getActual()->getGenericArgs()[mismatchIdx];
1059+
auto requiredArgTy = getRequired()->getGenericArgs()[mismatchIdx];
1060+
1061+
emitDiagnostic(diag::types_not_equal_in_decl_ref, memberRef, baseTy,
1062+
actualArgTy, requiredArgTy);
1063+
emitDiagnosticAt(memberRef, diag::decl_declared_here, memberRef);
1064+
return true;
1065+
}
1066+
1067+
emitDiagnostic(
1068+
diag::cannot_reference_conditional_member_on_base_multiple_mismatches,
1069+
memberRef, baseTy);
1070+
emitDiagnosticAt(memberRef, diag::decl_declared_here, memberRef);
1071+
emitNotesForMismatches();
1072+
return true;
1073+
}
1074+
10191075
default:
10201076
break;
10211077
}

0 commit comments

Comments
 (0)