Skip to content

Commit 924bbb9

Browse files
committed
[Distributed] Don't crash in thunk generation when missing SR conformance
1 parent 7b3c47f commit 924bbb9

8 files changed

+90
-17
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ getDistributedSerializationRequirements(
123123
/// Given any set of generic requirements, locate those which are about the
124124
/// `SerializationRequirement`. Those need to be applied in the parameter and
125125
/// return type checking of distributed targets.
126-
llvm::SmallPtrSet<ProtocolDecl *, 2>
126+
void
127127
extractDistributedSerializationRequirements(
128-
ASTContext &C, ArrayRef<Requirement> allRequirements);
128+
ASTContext &C,
129+
ArrayRef<Requirement> allRequirements,
130+
llvm::SmallPtrSet<ProtocolDecl *, 2> &into);
129131

130132
}
131133

lib/AST/DistributedDecl.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,10 +1268,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12681268
return true;
12691269
}
12701270

1271-
llvm::SmallPtrSet<ProtocolDecl *, 2>
1271+
void
12721272
swift::extractDistributedSerializationRequirements(
1273-
ASTContext &C, ArrayRef<Requirement> allRequirements) {
1274-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
1273+
ASTContext &C,
1274+
ArrayRef<Requirement> allRequirements,
1275+
llvm::SmallPtrSet<ProtocolDecl *, 2> &into) {
12751276
auto DA = C.getDistributedActorDecl();
12761277
auto daSerializationReqAssocType =
12771278
DA->getAssociatedType(C.Id_SerializationRequirement);
@@ -1293,7 +1294,7 @@ swift::extractDistributedSerializationRequirements(
12931294
auto requirementProto = req.getSecondType();
12941295
if (auto proto = dyn_cast_or_null<ProtocolDecl>(
12951296
requirementProto->getAnyNominal())) {
1296-
serializationReqs.insert(proto);
1297+
into.insert(proto);
12971298
} else {
12981299
auto serialReqType = requirementProto->castTo<ExistentialType>()
12991300
->getConstraintType()
@@ -1302,14 +1303,12 @@ swift::extractDistributedSerializationRequirements(
13021303
flattenDistributedSerializationTypeToRequiredProtocols(
13031304
serialReqType);
13041305
for (auto p : flattenedRequirements) {
1305-
serializationReqs.insert(p);
1306+
into.insert(p);
13061307
}
13071308
}
13081309
}
13091310
}
13101311
}
1311-
1312-
return serializationReqs;
13131312
}
13141313

13151314
/******************************************************************************/

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,9 +841,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator,
841841
if (!distributedTarget->isDistributed())
842842
return nullptr;
843843
}
844-
845844
assert(distributedTarget);
846845

846+
// This evaluation type-check by now was already computed and cached;
847+
// We need to check in order to avoid emitting a THUNK for a distributed func
848+
// which had errors; as the thunk then may also cause un-addressable issues and confusion.
849+
if (swift::checkDistributedFunction(distributedTarget)) {
850+
return nullptr;
851+
}
852+
847853
auto &C = distributedTarget->getASTContext();
848854

849855
if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) {

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2067,7 +2067,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) {
20672067
return (prop &&
20682068
prop->isFinal() &&
20692069
isa<ClassDecl>(prop->getDeclContext()) &&
2070-
cast<ClassDecl>(prop->getDeclContext())->isActor() &&
2070+
cast<ClassDecl>(prop->getDeclContext())->isAnyActor() &&
20712071
!prop->isStatic() &&
20722072
prop->getName() == ctx.Id_unownedExecutor &&
20732073
prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl());

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,22 @@ bool CheckDistributedFunctionRequest::evaluate(
498498
// SerializationRequirement
499499
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements;
500500
if (auto extension = dyn_cast<ExtensionDecl>(DC)) {
501-
serializationRequirements = extractDistributedSerializationRequirements(
502-
C, extension->getGenericRequirements());
501+
auto actorOrProtocol = extension->getExtendedNominal();
502+
if (auto actor = dyn_cast<ClassDecl>(actorOrProtocol)) {
503+
assert(actor->isAnyActor());
504+
serializationRequirements = getDistributedSerializationRequirementProtocols(
505+
getDistributedActorSystemType(actor)->getAnyNominal(),
506+
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
507+
} else if (auto protocol = dyn_cast<ProtocolDecl>(actorOrProtocol)) {
508+
extractDistributedSerializationRequirements(
509+
C, protocol->getGenericRequirements(),
510+
/*into=*/serializationRequirements);
511+
extractDistributedSerializationRequirements(
512+
C, extension->getGenericRequirements(),
513+
/*into=*/serializationRequirements);
514+
} else {
515+
// ignore
516+
}
503517
} else if (auto actor = dyn_cast<ClassDecl>(DC)) {
504518
serializationRequirements = getDistributedSerializationRequirementProtocols(
505519
getDistributedActorSystemType(actor)->getAnyNominal(),
@@ -546,6 +560,7 @@ bool CheckDistributedFunctionRequest::evaluate(
546560
if (auto paramNominalTy = paramTy->getAnyNominal()) {
547561
addCodableFixIt(paramNominalTy, diag);
548562
} // else, no nominal type to suggest the fixit for, e.g. a closure
563+
549564
return true;
550565
}
551566
}
@@ -795,11 +810,11 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal
795810
(void)nominal->getDistributedActorIDProperty();
796811
}
797812

798-
void TypeChecker::checkDistributedFunc(FuncDecl *func) {
813+
bool TypeChecker::checkDistributedFunc(FuncDecl *func) {
799814
if (!func->isDistributed())
800-
return;
815+
return false;
801816

802-
swift::checkDistributedFunction(func);
817+
return swift::checkDistributedFunction(func);
803818
}
804819

805820
llvm::SmallPtrSet<ProtocolDecl *, 2>

lib/Sema/TypeCheckStmt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2778,6 +2778,14 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &eval,
27782778
// So, build out the body now.
27792779
ASTScope::expandFunctionBody(AFD);
27802780

2781+
if (AFD->isDistributedThunk()) {
2782+
if (auto func = dyn_cast<FuncDecl>(AFD)) {
2783+
if (TypeChecker::checkDistributedFunc(func)) {
2784+
return errorBody();
2785+
}
2786+
}
2787+
}
2788+
27812789
// Type check the function body if needed.
27822790
bool hadError = false;
27832791
if (!alreadyTypeChecked) {

lib/Sema/TypeChecker.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,9 @@ diagnosePotentialUnavailability(SourceRange ReferenceRange,
11321132
void checkDistributedActor(SourceFile *SF, NominalTypeDecl *decl);
11331133

11341134
/// Type check a single 'distributed func' declaration.
1135-
void checkDistributedFunc(FuncDecl *func);
1135+
///
1136+
/// Returns `true` if there was an error.
1137+
bool checkDistributedFunc(FuncDecl *func);
11361138

11371139
bool checkAvailability(SourceRange ReferenceRange,
11381140
AvailabilityContext Availability,
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/Inputs/FakeDistributedActorSystems.swift
3+
// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/Inputs/FakeDistributedActorSystems.swift 2> %t/output.txt || echo 'failed expectedly'
4+
// RUN: %FileCheck %s < %t/output.txt
5+
6+
// REQUIRES: concurrency
7+
// REQUIRES: distributed
8+
9+
// rdar://76038845
10+
// UNSUPPORTED: use_os_stdlib
11+
// UNSUPPORTED: back_deployment_runtime
12+
13+
import Distributed
14+
15+
// Notes:
16+
// This test specifically is not just a -typecheck -verify test but attempts to generate the whole module.
17+
// This is because we may be emitting errors but otherwise still attempt to emit a thunk for an "error-ed"
18+
// distributed function, which would then crash in later phases of compilation when we try to get types
19+
// of the `func` the THUNK is based on.
20+
21+
typealias DefaultDistributedActorSystem = LocalTestingDistributedActorSystem
22+
23+
distributed actor Service {
24+
}
25+
26+
extension Service {
27+
distributed func boombox(_ id: Box) async throws {}
28+
// CHECK: parameter '' of type 'Box' in distributed instance method does not conform to serialization requirement 'Codable'
29+
30+
distributed func boxIt() async throws -> Box { fatalError() }
31+
// CHECK: result type 'Box' of distributed instance method 'boxIt' does not conform to serialization requirement 'Codable'
32+
}
33+
34+
public enum Box: Hashable { case boom }
35+
36+
@main struct Main {
37+
static func main() async {
38+
try? await Service(actorSystem: .init()).boombox(Box.boom)
39+
}
40+
}
41+

0 commit comments

Comments
 (0)