Skip to content

Commit 4b63624

Browse files
authored
Merge pull request #69502 from ktoso/pick-wip-dont-crash-missing-conformance-param
🍒[5.10][Distributed] Don't crash in thunk generation when missing SR conformance
2 parents 4312c32 + af211dd commit 4b63624

9 files changed

+189
-228
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor);
5050
/// Similar to `getDistributedSerializationRequirementType`, however, from the
5151
/// perspective of a concrete function. This way we're able to get the
5252
/// serialization requirement for specific members, also in protocols.
53-
Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member);
53+
Type getSerializationRequirementTypesForMember(
54+
ValueDecl *member, llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements);
5455

5556
/// Get specific 'SerializationRequirement' as defined in 'nominal'
5657
/// type, which must conform to the passed 'protocol' which is expected
@@ -91,19 +92,13 @@ llvm::SmallPtrSet<ProtocolDecl *, 2>
9192
getDistributedSerializationRequirementProtocols(
9293
NominalTypeDecl *decl, ProtocolDecl* protocol);
9394

94-
/// Desugar and flatten the `SerializationRequirement` type into a set of
95-
/// specific protocol declarations.
96-
llvm::SmallPtrSet<ProtocolDecl *, 2>
97-
flattenDistributedSerializationTypeToRequiredProtocols(
98-
TypeBase *serializationRequirement);
99-
10095
/// Check if the `allRequirements` represent *exactly* the
10196
/// `Encodable & Decodable` (also known as `Codable`) requirement.
10297
///
10398
/// If so, we can emit slightly nicer diagnostics.
10499
bool checkDistributedSerializationRequirementIsExactlyCodable(
105100
ASTContext &C,
106-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements);
101+
Type type);
107102

108103
/// Get the `SerializationRequirement`, explode it into the specific
109104
/// protocol requirements and insert them into `requirements`.
@@ -120,15 +115,6 @@ getDistributedSerializationRequirements(
120115
ProtocolDecl *protocol,
121116
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);
122117

123-
/// Given any set of generic requirements, locate those which are about the
124-
/// `SerializationRequirement`. Those need to be applied in the parameter and
125-
/// return type checking of distributed targets.
126-
llvm::SmallPtrSet<ProtocolDecl *, 2>
127-
extractDistributedSerializationRequirements(
128-
ASTContext &C, ArrayRef<Requirement> allRequirements);
129-
130118
}
131119

132-
// ==== ------------------------------------------------------------------------
133-
134120
#endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */

lib/AST/DistributedDecl.cpp

Lines changed: 42 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member)
9595
llvm_unreachable("Unable to fetch ActorSystem type!");
9696
}
9797

98-
Type swift::getConcreteReplacementForMemberSerializationRequirement(
99-
ValueDecl *member) {
98+
Type swift::getSerializationRequirementTypesForMember(
99+
ValueDecl *member,
100+
llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements) {
100101
auto &C = member->getASTContext();
101102
auto *DC = member->getDeclContext();
102103
auto DA = C.getDistributedActorDecl();
@@ -106,17 +107,21 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
106107
return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl());
107108
}
108109

109-
/// === Maybe the value is declared in a protocol?
110-
if (auto protocol = DC->getSelfProtocolDecl()) {
110+
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
111+
->getDeclaredInterfaceType();
112+
113+
if (DC->getSelfProtocolDecl()) {
111114
GenericSignature signature;
112115
if (auto *genericContext = member->getAsGenericContext()) {
113116
signature = genericContext->getGenericSignature();
114117
} else {
115118
signature = DC->getGenericSignatureOfContext();
116119
}
117120

118-
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
119-
->getDeclaredInterfaceType();
121+
// Also store all `SerializationRequirement : SomeProtocol` requirements
122+
for (auto proto: signature->getRequiredProtocols(SerReqAssocType)) {
123+
serializationRequirements.insert(proto);
124+
}
120125

121126
// Note that this may be null, e.g. if we're a distributed func inside
122127
// a protocol that did not declare a specific actor system requirement.
@@ -178,13 +183,7 @@ Type swift::getDistributedActorSystemResultHandlerType(
178183
auto module = system->getParentModule();
179184
Type selfType = system->getSelfInterfaceType();
180185
auto conformance = module->lookupConformance(selfType, DAS);
181-
auto witness =
182-
conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler);
183-
if (auto alias = dyn_cast<TypeAliasType>(witness.getPointer())) {
184-
return alias->getDecl()->getUnderlyingType();
185-
} else {
186-
return witness;
187-
}
186+
return conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler);
188187
}
189188

190189
Type swift::getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system) {
@@ -346,63 +345,39 @@ swift::getDistributedSerializationRequirements(
346345
if (existentialRequirementTy->isAny())
347346
return true; // we're done here, any means there are no requirements
348347

349-
if (!existentialRequirementTy->isExistentialType()) {
350-
// SerializationRequirement must be an existential type
351-
return false;
352-
}
353-
354-
ExistentialType *serialReqType = existentialRequirementTy
355-
->castTo<ExistentialType>();
348+
auto *serialReqType = existentialRequirementTy->getAs<ExistentialType>();
356349
if (!serialReqType || serialReqType->hasError()) {
357350
return false;
358351
}
359352

360-
auto desugaredTy = serialReqType->getConstraintType()->getDesugaredType();
361-
auto flattenedRequirements =
362-
flattenDistributedSerializationTypeToRequiredProtocols(
363-
desugaredTy);
364-
for (auto p : flattenedRequirements) {
353+
auto layout = serialReqType->getExistentialLayout();
354+
for (auto p : layout.getProtocols()) {
365355
requirementProtos.insert(p);
366356
}
367357

368358
return true;
369359
}
370360

371-
llvm::SmallPtrSet<ProtocolDecl *, 2>
372-
swift::flattenDistributedSerializationTypeToRequiredProtocols(
373-
TypeBase *serializationRequirement) {
374-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
375-
if (auto composition =
376-
serializationRequirement->getAs<ProtocolCompositionType>()) {
377-
for (auto member : composition->getMembers()) {
378-
if (auto comp = member->getAs<ProtocolCompositionType>()) {
379-
for (auto protocol :
380-
flattenDistributedSerializationTypeToRequiredProtocols(comp)) {
381-
serializationReqs.insert(protocol);
382-
}
383-
} else if (auto *protocol = member->getAs<ProtocolType>()) {
384-
serializationReqs.insert(protocol->getDecl());
385-
}
386-
}
387-
} else {
388-
auto protocol = serializationRequirement->castTo<ProtocolType>()->getDecl();
389-
serializationReqs.insert(protocol);
390-
}
391-
392-
return serializationReqs;
393-
}
394-
395361
bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
396362
ASTContext &C,
397-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
363+
Type type) {
364+
if (!type)
365+
return false;
366+
367+
if (type->hasError())
368+
return false;
369+
398370
auto encodable = C.getProtocol(KnownProtocolKind::Encodable);
399371
auto decodable = C.getProtocol(KnownProtocolKind::Decodable);
400372

401-
if (allRequirements.size() != 2)
373+
auto layout = type->getExistentialLayout();
374+
auto protocols = layout.getProtocols();
375+
376+
if (protocols.size() != 2)
402377
return false;
403378

404-
return allRequirements.count(encodable) &&
405-
allRequirements.count(decodable);
379+
return std::count(protocols.begin(), protocols.end(), encodable) == 1 &&
380+
std::count(protocols.begin(), protocols.end(), decodable) == 1;
406381
}
407382

408383
/******************************************************************************/
@@ -571,25 +546,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
571546

572547
// --- Check requirement: conforms_to: Act DistributedActor
573548
auto actorReq = requirements[0];
574-
auto distActorTy = C.getProtocol(KnownProtocolKind::DistributedActor)
575-
->getInterfaceType()
576-
->getMetatypeInstanceType();
577549
if (actorReq.getKind() != RequirementKind::Conformance) {
578550
return false;
579551
}
580-
if (!actorReq.getSecondType()->isEqual(distActorTy)) {
552+
if (!actorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::DistributedActor)) {
581553
return false;
582554
}
583555

584556
// --- Check requirement: conforms_to: Err Error
585557
auto errorReq = requirements[1];
586-
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
587-
->getInterfaceType()
588-
->getMetatypeInstanceType();
589558
if (errorReq.getKind() != RequirementKind::Conformance) {
590559
return false;
591560
}
592-
if (!errorReq.getSecondType()->isEqual(errorTy)) {
561+
if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) {
593562
return false;
594563
}
595564

@@ -604,10 +573,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
604573
assert(ResParam && "Non void function, yet no Res generic parameter found");
605574
if (auto func = dyn_cast<FuncDecl>(this)) {
606575
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
607-
->getMetatypeInstanceType()
608-
->getDesugaredType();
576+
->getMetatypeInstanceType();
609577
auto resultParamType = func->mapTypeIntoContext(
610-
ResParam->getInterfaceType()->getMetatypeInstanceType());
578+
ResParam->getDeclaredInterfaceType());
611579
// The result of the function must be the `Res` generic argument.
612580
if (!resultType->isEqual(resultParamType)) {
613581
return false;
@@ -803,12 +771,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
803771

804772
// the <Value> of the RemoteCallArgument<Value>
805773
auto remoteCallArgValueGenericTy =
806-
mapTypeIntoContext(argGenericParams[0]->getInterfaceType())
807-
->getDesugaredType()
808-
->getMetatypeInstanceType();
774+
mapTypeIntoContext(argGenericParams[0]->getDeclaredInterfaceType());
809775
// expected (the <Value> from the recordArgument<Value>)
810776
auto expectedGenericParamTy = mapTypeIntoContext(
811-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
777+
ArgumentParam->getDeclaredInterfaceType());
812778

813779
if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) {
814780
return false;
@@ -938,11 +904,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con
938904
// ...
939905

940906
auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType())
941-
->getMetatypeInstanceType()
942-
->getDesugaredType();
907+
->getMetatypeInstanceType();
943908

944909
auto resultParamType = func->mapTypeIntoContext(
945-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
910+
ArgumentParam->getDeclaredInterfaceType());
946911

947912
// The result of the function must be the `Res` generic argument.
948913
if (!resultType->isEqual(resultParamType)) {
@@ -1052,13 +1017,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
10521017

10531018
// --- Check requirement: conforms_to: Err Error
10541019
auto errorReq = requirements[0];
1055-
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
1056-
->getInterfaceType()
1057-
->getMetatypeInstanceType();
10581020
if (errorReq.getKind() != RequirementKind::Conformance) {
10591021
return false;
10601022
}
1061-
if (!errorReq.getSecondType()->isEqual(errorTy)) {
1023+
if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) {
10621024
return false;
10631025
}
10641026

@@ -1145,10 +1107,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c
11451107
// --- Check: Argument: SerializationRequirement
11461108
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
11471109
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
1148-
->getMetatypeInstanceType()
1149-
->getDesugaredType();
1110+
->getMetatypeInstanceType();
11501111
auto resultParamType = func->mapTypeIntoContext(
1151-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
1112+
ArgumentParam->getDeclaredInterfaceType());
11521113
// The result of the function must be the `Res` generic argument.
11531114
if (!resultType->isEqual(resultParamType)) {
11541115
return false;
@@ -1243,11 +1204,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12431204
// === Check generic parameters in detail
12441205
// --- Check: Argument: SerializationRequirement
12451206
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
1246-
auto argumentType = func->mapTypeIntoContext(valueParam->getInterfaceType())
1247-
->getMetatypeInstanceType()
1248-
->getDesugaredType();
1207+
auto argumentType = func->mapTypeIntoContext(
1208+
valueParam->getInterfaceType()->getMetatypeInstanceType());
12491209
auto resultParamType = func->mapTypeIntoContext(
1250-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
1210+
ArgumentParam->getDeclaredInterfaceType());
12511211
// The result of the function must be the `Res` generic argument.
12521212
if (!argumentType->isEqual(resultParamType)) {
12531213
return false;
@@ -1268,50 +1228,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12681228
return true;
12691229
}
12701230

1271-
llvm::SmallPtrSet<ProtocolDecl *, 2>
1272-
swift::extractDistributedSerializationRequirements(
1273-
ASTContext &C, ArrayRef<Requirement> allRequirements) {
1274-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
1275-
auto DA = C.getDistributedActorDecl();
1276-
auto daSerializationReqAssocType =
1277-
DA->getAssociatedType(C.Id_SerializationRequirement);
1278-
auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType();
1279-
1280-
for (auto req : allRequirements) {
1281-
if (req.getSecondType()->isAny()) {
1282-
continue;
1283-
}
1284-
if (!req.getFirstType()->hasDependentMember())
1285-
continue;
1286-
1287-
if (auto dependentMemberType =
1288-
req.getFirstType()->castTo<DependentMemberType>()) {
1289-
auto dependentTy =
1290-
dependentMemberType->getAssocType()->getInterfaceType();
1291-
1292-
if (dependentTy->isEqual(daSystemSerializationReqTy)) {
1293-
auto requirementProto = req.getSecondType();
1294-
if (auto proto = dyn_cast_or_null<ProtocolDecl>(
1295-
requirementProto->getAnyNominal())) {
1296-
serializationReqs.insert(proto);
1297-
} else {
1298-
auto serialReqType = requirementProto->castTo<ExistentialType>()
1299-
->getConstraintType()
1300-
->getDesugaredType();
1301-
auto flattenedRequirements =
1302-
flattenDistributedSerializationTypeToRequiredProtocols(
1303-
serialReqType);
1304-
for (auto p : flattenedRequirements) {
1305-
serializationReqs.insert(p);
1306-
}
1307-
}
1308-
}
1309-
}
1310-
}
1311-
1312-
return serializationReqs;
1313-
}
1314-
13151231
/******************************************************************************/
13161232
/********************** Distributed Functions *********************************/
13171233
/******************************************************************************/

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator,
842842
if (!distributedTarget->isDistributed())
843843
return nullptr;
844844
}
845-
846845
assert(distributedTarget);
847846

847+
// This evaluation type-check by now was already computed and cached;
848+
// We need to check in order to avoid emitting a THUNK for a distributed func
849+
// which had errors; as the thunk then may also cause un-addressable issues and confusion.
850+
if (swift::checkDistributedFunction(distributedTarget)) {
851+
return nullptr;
852+
}
853+
848854
auto &C = distributedTarget->getASTContext();
849855

850856
if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) {

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2067,7 +2067,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) {
20672067
return (prop &&
20682068
prop->isFinal() &&
20692069
isa<ClassDecl>(prop->getDeclContext()) &&
2070-
cast<ClassDecl>(prop->getDeclContext())->isActor() &&
2070+
cast<ClassDecl>(prop->getDeclContext())->isAnyActor() &&
20712071
!prop->isStatic() &&
20722072
prop->getName() == ctx.Id_unownedExecutor &&
20732073
prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl());

0 commit comments

Comments
 (0)