Skip to content

Commit ed5007f

Browse files
authored
Merge pull request #77584 from ktoso/wip-check-array-calls
2 parents 66321f7 + 746720c commit ed5007f

17 files changed

+434
-74
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ PROTOCOL(Differentiable)
119119

120120
// Distributed Actors
121121
PROTOCOL(DistributedActor)
122+
PROTOCOL_(DistributedActorStub)
122123
PROTOCOL(DistributedActorSystem)
123124
PROTOCOL(DistributedTargetInvocationEncoder)
124125
PROTOCOL(DistributedTargetInvocationDecoder)

lib/AST/ASTMangler.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4493,12 +4493,33 @@ void ASTMangler::appendDistributedThunk(
44934493
return nullptr;
44944494
};
44954495

4496-
if (auto *P = referenceInProtocolContextOrRequirement()) {
4497-
appendContext(P->getDeclContext(), base,
4498-
thunk->getAlternateModuleName());
4499-
appendIdentifier(Twine("$", P->getNameStr()).str());
4500-
appendOperator("C"); // necessary for roundtrip, though we don't use it
4496+
// Determine if we should mangle with a $Target substitute decl context,
4497+
// this matters for @Resolvable calls / protocol calls where the caller
4498+
// does not know the type of the recipient distributed actor, and we use the
4499+
// $Target type as substitute to then generically invoke it on the type of the
4500+
// recipient, whichever 'protocol Type'-conforming type it will be.
4501+
NominalTypeDecl *stubActorDecl = nullptr;
4502+
if (auto P = referenceInProtocolContextOrRequirement()) {
4503+
auto &C = thunk->getASTContext();
4504+
auto M = thunk->getModuleContext();
4505+
4506+
SmallVector<ValueDecl *, 1> stubClassLookupResults;
4507+
C.lookupInModule(M, llvm::Twine("$", P->getNameStr()).str(), stubClassLookupResults);
4508+
4509+
assert(stubClassLookupResults.size() <= 1 && "Found multiple distributed stub types!");
4510+
if (stubClassLookupResults.size() > 0) {
4511+
stubActorDecl =
4512+
dyn_cast_or_null<NominalTypeDecl>(stubClassLookupResults.front());
4513+
}
4514+
}
4515+
4516+
if (stubActorDecl) {
4517+
// Effectively mangle the thunk as if it was declared on the $StubTarget
4518+
// type, rather than on a `protocol Target`.
4519+
appendContext(stubActorDecl, base, thunk->getAlternateModuleName());
45014520
} else {
4521+
// There's no need to replace the context, this is a normal concrete type
4522+
// remote call identifier.
45024523
appendContextOf(thunk, base);
45034524
}
45044525

lib/IRGen/GenDistributed.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,8 @@ void IRGenModule::emitDistributedTargetAccessor(ThunkOrRequirement target) {
379379
IRGenMangler mangler(Context);
380380

381381
addAccessibleFunction(AccessibleFunction::forDistributed(
382-
mangler.mangleDistributedThunkRecord(targetDecl),
383-
mangler.mangleDistributedThunk(targetDecl),
382+
/*recordName=*/mangler.mangleDistributedThunkRecord(targetDecl),
383+
/*accessorName=*/mangler.mangleDistributedThunk(targetDecl),
384384
accessor.getTargetType(),
385385
getAddrOfAsyncFunctionPointer(accessorRef)));
386386
}

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6956,6 +6956,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
69566956
case KnownProtocolKind::Identifiable:
69576957
case KnownProtocolKind::Actor:
69586958
case KnownProtocolKind::DistributedActor:
6959+
case KnownProtocolKind::DistributedActorStub:
69596960
case KnownProtocolKind::DistributedActorSystem:
69606961
case KnownProtocolKind::DistributedTargetInvocationEncoder:
69616962
case KnownProtocolKind::DistributedTargetInvocationDecoder:

lib/Macros/Sources/SwiftMacros/DistributedResolvableMacro.swift

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ extension DistributedResolvableMacro {
130130
return []
131131
}
132132

133-
var isGenericStub = false
133+
var isGenericOverActorSystem = false
134134
var specificActorSystemRequirement: TypeSyntax?
135135

136136
let accessModifiers = proto.accessControlModifiers
@@ -140,15 +140,15 @@ extension DistributedResolvableMacro {
140140
case .conformanceRequirement(let conformanceReq)
141141
where conformanceReq.leftType.isActorSystem:
142142
specificActorSystemRequirement = conformanceReq.rightType.trimmed
143-
isGenericStub = true
143+
isGenericOverActorSystem = true
144144

145145
case .sameTypeRequirement(let sameTypeReq):
146146
switch sameTypeReq.leftType {
147147
case .type(let type) where type.isActorSystem:
148148
switch sameTypeReq.rightType.trimmed {
149149
case .type(let rightType):
150150
specificActorSystemRequirement = rightType
151-
isGenericStub = false
151+
isGenericOverActorSystem = false
152152

153153
case .expr:
154154
throw DiagnosticsError(
@@ -167,41 +167,78 @@ extension DistributedResolvableMacro {
167167
}
168168
}
169169

170-
if isGenericStub, let specificActorSystemRequirement {
171-
return [
172-
"""
173-
\(proto.modifiers) distributed actor $\(proto.name.trimmed)<ActorSystem>: \(proto.name.trimmed),
174-
Distributed._DistributedActorStub
175-
where ActorSystem: \(specificActorSystemRequirement)
176-
{ }
177-
"""
178-
]
179-
} else if let specificActorSystemRequirement {
180-
return [
181-
"""
182-
\(proto.modifiers) distributed actor $\(proto.name.trimmed): \(proto.name.trimmed),
183-
Distributed._DistributedActorStub
184-
{
185-
\(typealiasActorSystem(access: accessModifiers, proto, specificActorSystemRequirement))
186-
}
187-
"""
188-
]
170+
var primaryAssociatedTypes: [PrimaryAssociatedTypeSyntax] = []
171+
if let primaryTypes = proto.primaryAssociatedTypeClause?.primaryAssociatedTypes {
172+
primaryAssociatedTypes.append(contentsOf: primaryTypes)
173+
}
174+
175+
// The $Stub is always generic over the actor system: $Stub<ActorSystem>
176+
var primaryTypeParams: [String] = primaryAssociatedTypes.map {
177+
$0.as(PrimaryAssociatedTypeSyntax.self)!.name.trimmed.text
178+
}
179+
180+
// Don't duplicate the ActorSystem type parameter if it already was declared
181+
// on the protocol as a primary associated type;
182+
// otherwise, add it as first primary associated type.
183+
let actorSystemTypeParam: [String]
184+
if primaryTypeParams.contains("ActorSystem") {
185+
actorSystemTypeParam = []
186+
} else if isGenericOverActorSystem {
187+
actorSystemTypeParam = ["ActorSystem"]
189188
} else {
189+
actorSystemTypeParam = []
190+
}
191+
192+
// Prepend the actor system type parameter, as we want it to be the first one
193+
primaryTypeParams = actorSystemTypeParam + primaryTypeParams
194+
let typeParamsClause =
195+
primaryTypeParams.isEmpty ? "" : "<" + primaryTypeParams.joined(separator: ", ") + ">"
196+
197+
var whereClause: String = ""
198+
do {
199+
let associatedTypeDecls = proto.associatedTypeDecls
200+
var typeParamConstraints: [String] = []
201+
for typeParamName in primaryTypeParams {
202+
if let decl = associatedTypeDecls[typeParamName] {
203+
if let inheritanceClause = decl.inheritanceClause {
204+
typeParamConstraints.append("\(typeParamName)\(inheritanceClause)")
205+
}
206+
}
207+
}
208+
209+
if isGenericOverActorSystem, let specificActorSystemRequirement {
210+
typeParamConstraints = ["ActorSystem: \(specificActorSystemRequirement)"] + typeParamConstraints
211+
}
212+
213+
if !typeParamConstraints.isEmpty {
214+
whereClause += "\n where " + typeParamConstraints.joined(separator: ",\n ")
215+
}
216+
}
217+
218+
let stubActorBody: String
219+
if isGenericOverActorSystem {
190220
// there may be no `where` clause specifying an actor system,
191221
// but perhaps there is a typealias (or extension with a typealias),
192222
// specifying a concrete actor system so we let this synthesize
193223
// an empty `$Greeter` -- this may fail, or succeed depending on
194224
// surrounding code using a default distributed actor system,
195225
// or extensions providing it.
196-
return [
197-
"""
198-
\(proto.modifiers) distributed actor $\(proto.name.trimmed): \(proto.name.trimmed),
199-
Distributed._DistributedActorStub
200-
{
201-
}
202-
"""
203-
]
226+
stubActorBody = ""
227+
} else if let specificActorSystemRequirement {
228+
stubActorBody = "\(typealiasActorSystem(access: accessModifiers, proto, specificActorSystemRequirement))"
229+
} else {
230+
stubActorBody = ""
204231
}
232+
233+
return [
234+
"""
235+
\(proto.modifiers) distributed actor $\(proto.name.trimmed)\(raw: typeParamsClause): \(proto.name.trimmed),
236+
Distributed._DistributedActorStub \(raw: whereClause)
237+
{
238+
\(raw: stubActorBody)
239+
}
240+
"""
241+
]
205242
}
206243

207244
private static func typealiasActorSystem(access: DeclModifierListSyntax,
@@ -253,6 +290,23 @@ extension DeclModifierSyntax {
253290
}
254291
}
255292

293+
extension ProtocolDeclSyntax {
294+
var associatedTypeDecls: [String: AssociatedTypeDeclSyntax] {
295+
let visitor = AssociatedTypeDeclVisitor(viewMode: .all)
296+
visitor.walk(self)
297+
return visitor.associatedTypeDecls
298+
}
299+
300+
final class AssociatedTypeDeclVisitor: SyntaxVisitor {
301+
var associatedTypeDecls: [String: AssociatedTypeDeclSyntax] = [:]
302+
303+
override func visit(_ node: AssociatedTypeDeclSyntax) -> SyntaxVisitorContinueKind {
304+
associatedTypeDecls[node.name.text] = node
305+
return .skipChildren
306+
}
307+
}
308+
}
309+
256310
// ===== -----------------------------------------------------------------------
257311
// MARK: @Distributed.Resolvable macro errors
258312

stdlib/public/runtime/AccessibleFunction.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
#include "swift/Demangling/Demangler.h"
2121
#include "swift/Runtime/AccessibleFunction.h"
2222
#include "swift/Runtime/Concurrent.h"
23+
#include "swift/Runtime/EnvironmentVariables.h"
2324
#include "swift/Runtime/Metadata.h"
25+
#include "swift/Threading/Once.h"
2426
#include "Tracing.h"
2527

2628
#include <cstdint>
@@ -98,6 +100,29 @@ static Lazy<AccessibleFunctionsState> Functions;
98100

99101
} // end anonymous namespace
100102

103+
LLVM_ATTRIBUTE_UNUSED
104+
static void _dumpAccessibleFunctionRecords(void *context) {
105+
auto &S = Functions.get();
106+
107+
fprintf(stderr, "==== Accessible Function Records ====\n");
108+
int count = 0;
109+
for (const auto &section : S.SectionsToScan.snapshot()) {
110+
for (auto &record : section) {
111+
auto recordName =
112+
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
113+
auto demangledRecordName =
114+
swift::Demangle::demangleSymbolAsString(recordName);
115+
fprintf(stderr, "Record name: %s\n", recordName.data());
116+
fprintf(stderr, " Demangled: %s\n", demangledRecordName.c_str());
117+
fprintf(stderr, " Function Ptr: %p\n", record.Function.get());
118+
fprintf(stderr, " Flags.IsDistributed: %d\n", record.Flags.isDistributed());
119+
++count;
120+
}
121+
}
122+
fprintf(stderr, "Record count: %d\n", count);
123+
fprintf(stderr, "==== End of Accessible Function Records ====\n");
124+
}
125+
101126
static void _registerAccessibleFunctions(AccessibleFunctionsState &C,
102127
AccessibleFunctionsSection section) {
103128
C.SectionsToScan.push_back(section);
@@ -119,27 +144,6 @@ void swift::addImageAccessibleFunctionsBlockCallback(
119144
addImageAccessibleFunctionsBlockCallbackUnsafe(baseAddress, functions, size);
120145
}
121146

122-
// TODO(distributed): expose dumping records via a flag
123-
LLVM_ATTRIBUTE_UNUSED
124-
static void _dumpAccessibleFunctionRecords() {
125-
auto &S = Functions.get();
126-
127-
fprintf(stderr, "==== Accessible Function Records ====\n");
128-
int count = 0;
129-
for (const auto &section : S.SectionsToScan.snapshot()) {
130-
for (auto &record : section) {
131-
auto recordName =
132-
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
133-
fprintf(stderr, "Record name: %s\n", recordName.data());
134-
fprintf(stderr, " Function Ptr: %p\n", record.Function.get());
135-
fprintf(stderr, " Flags.IsDistributed: %d\n", record.Flags.isDistributed());
136-
++count;
137-
}
138-
}
139-
fprintf(stderr, "Record count: %d\n", count);
140-
fprintf(stderr, "==== End of Accessible Function Records ====\n");
141-
}
142-
143147
static const AccessibleFunctionRecord *
144148
_searchForFunctionRecord(AccessibleFunctionsState &S, llvm::StringRef name) {
145149
auto traceState = runtime::trace::accessible_function_scan_begin(name);
@@ -148,8 +152,9 @@ _searchForFunctionRecord(AccessibleFunctionsState &S, llvm::StringRef name) {
148152
for (auto &record : section) {
149153
auto recordName =
150154
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
151-
if (recordName == name)
155+
if (recordName == name) {
152156
return traceState.end(&record);
157+
}
153158
}
154159
}
155160
return nullptr;
@@ -160,13 +165,19 @@ const AccessibleFunctionRecord *
160165
swift::runtime::swift_findAccessibleFunction(const char *targetNameStart,
161166
size_t targetNameLength) {
162167
auto &S = Functions.get();
163-
164168
llvm::StringRef name{targetNameStart, targetNameLength};
169+
170+
if (swift::runtime::environment::SWIFT_DUMP_ACCESSIBLE_FUNCTIONS()) {
171+
static swift::once_t dumpAccessibleFunctionsToken;
172+
swift::once(dumpAccessibleFunctionsToken, _dumpAccessibleFunctionRecords, nullptr);
173+
}
174+
165175
// Look for an existing entry.
166176
{
167177
auto snapshot = S.Cache.snapshot();
168-
if (auto E = snapshot.find(name))
178+
if (auto E = snapshot.find(name)) {
169179
return E->getRecord();
180+
}
170181
}
171182

172183
// If entry doesn't exist (either record doesn't exist, hasn't been loaded, or

stdlib/public/runtime/EnvironmentVariables.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,10 @@ VARIABLE(SWIFT_IS_CURRENT_EXECUTOR_LEGACY_MODE_OVERRIDE, string, "",
128128
" 'legacy' (Legacy behavior), "
129129
" 'swift6' (Swift 6.0+ behavior)")
130130

131+
VARIABLE(SWIFT_DUMP_ACCESSIBLE_FUNCTIONS, string, "",
132+
"Dump a listing of all 'AccessibleFunctionRecord's upon first access. "
133+
"These are used to obtain function pointers from accessible function "
134+
"record names, e.g. by the Distributed runtime to invoke distributed "
135+
"functions.")
136+
131137
#undef VARIABLE

0 commit comments

Comments
 (0)