Skip to content

Commit 86f4820

Browse files
committed
[Distributed] Resolve mangling issues with @resolvable
1 parent 807b543 commit 86f4820

17 files changed

+423
-80
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ PROTOCOL(Differentiable)
119119

120120
// Distributed Actors
121121
PROTOCOL(DistributedActor)
122+
PROTOCOL_(DistributedActorStub)
122123
PROTOCOL(DistributedActorSystem)
123124
PROTOCOL(DistributedTargetInvocationEncoder)
124125
PROTOCOL(DistributedTargetInvocationDecoder)

lib/AST/ASTMangler.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4452,12 +4452,32 @@ void ASTMangler::appendDistributedThunk(
44524452
return nullptr;
44534453
};
44544454

4455-
if (auto *P = referenceInProtocolContextOrRequirement()) {
4456-
appendContext(P->getDeclContext(), base,
4457-
thunk->getAlternateModuleName());
4458-
appendIdentifier(Twine("$", P->getNameStr()).str());
4459-
appendOperator("C"); // necessary for roundtrip, though we don't use it
4455+
// Determine if we should mangle with a $Target substitute decl context,
4456+
// this matters for @Resolvable calls / protocol calls where the caller
4457+
// does not know the type of the recipient distributed actor, and we use the
4458+
// $Target type as substitute to then generically invoke it on the type of the
4459+
// recipient, whichever 'protocol Type'-conforming type it will be.
4460+
NominalTypeDecl *stubActorDecl = nullptr;
4461+
if (auto P = referenceInProtocolContextOrRequirement()) {
4462+
auto &C = thunk->getASTContext();
4463+
auto M = thunk->getModuleContext();
4464+
4465+
SmallVector<ValueDecl *, 1> stubClassLookupResults;
4466+
C.lookupInModule(M, ("$" + P->getNameStr()).str(), stubClassLookupResults);
4467+
4468+
if (stubClassLookupResults.size() > 0) {
4469+
stubActorDecl =
4470+
dyn_cast_or_null<NominalTypeDecl>(stubClassLookupResults.front());
4471+
}
4472+
}
4473+
4474+
if (stubActorDecl) {
4475+
// Effectively mangle the thunk as if it was declared on the $StubTarget
4476+
// type, rather than on a `protocol Target`.
4477+
appendContext(stubActorDecl, base, thunk->getAlternateModuleName());
44604478
} else {
4479+
// There's no need to replace the context, this is a normal concrete type
4480+
// remote call identifier.
44614481
appendContextOf(thunk, base);
44624482
}
44634483

lib/IRGen/GenDistributed.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,8 @@ void IRGenModule::emitDistributedTargetAccessor(ThunkOrRequirement target) {
379379
IRGenMangler mangler;
380380

381381
addAccessibleFunction(AccessibleFunction::forDistributed(
382-
mangler.mangleDistributedThunkRecord(targetDecl),
383-
mangler.mangleDistributedThunk(targetDecl),
382+
/*recordName=*/mangler.mangleDistributedThunkRecord(targetDecl),
383+
/*accessorName=*/mangler.mangleDistributedThunk(targetDecl),
384384
accessor.getTargetType(),
385385
getAddrOfAsyncFunctionPointer(accessorRef)));
386386
}

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6958,6 +6958,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
69586958
case KnownProtocolKind::Identifiable:
69596959
case KnownProtocolKind::Actor:
69606960
case KnownProtocolKind::DistributedActor:
6961+
case KnownProtocolKind::DistributedActorStub:
69616962
case KnownProtocolKind::DistributedActorSystem:
69626963
case KnownProtocolKind::DistributedTargetInvocationEncoder:
69636964
case KnownProtocolKind::DistributedTargetInvocationDecoder:

lib/Macros/Sources/SwiftMacros/DistributedResolvableMacro.swift

Lines changed: 90 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ extension DistributedResolvableMacro {
130130
return []
131131
}
132132

133-
var isGenericStub = false
133+
var isGenericOverActorSystem = false
134134
var specificActorSystemRequirement: TypeSyntax?
135135

136136
let accessModifiers = proto.accessControlModifiers
@@ -140,15 +140,15 @@ extension DistributedResolvableMacro {
140140
case .conformanceRequirement(let conformanceReq)
141141
where conformanceReq.leftType.isActorSystem:
142142
specificActorSystemRequirement = conformanceReq.rightType.trimmed
143-
isGenericStub = true
143+
isGenericOverActorSystem = true
144144

145145
case .sameTypeRequirement(let sameTypeReq):
146146
switch sameTypeReq.leftType {
147147
case .type(let type) where type.isActorSystem:
148148
switch sameTypeReq.rightType.trimmed {
149149
case .type(let rightType):
150150
specificActorSystemRequirement = rightType
151-
isGenericStub = false
151+
isGenericOverActorSystem = false
152152

153153
case .expr:
154154
throw DiagnosticsError(
@@ -167,41 +167,78 @@ extension DistributedResolvableMacro {
167167
}
168168
}
169169

170-
if isGenericStub, let specificActorSystemRequirement {
171-
return [
172-
"""
173-
\(proto.modifiers) distributed actor $\(proto.name.trimmed)<ActorSystem>: \(proto.name.trimmed),
174-
Distributed._DistributedActorStub
175-
where ActorSystem: \(specificActorSystemRequirement)
176-
{ }
177-
"""
178-
]
179-
} else if let specificActorSystemRequirement {
180-
return [
181-
"""
182-
\(proto.modifiers) distributed actor $\(proto.name.trimmed): \(proto.name.trimmed),
183-
Distributed._DistributedActorStub
184-
{
185-
\(typealiasActorSystem(access: accessModifiers, proto, specificActorSystemRequirement))
186-
}
187-
"""
188-
]
189-
} else {
190-
// there may be no `where` clause specifying an actor system,
191-
// but perhaps there is a typealias (or extension with a typealias),
192-
// specifying a concrete actor system so we let this synthesize
193-
// an empty `$Greeter` -- this may fail, or succeed depending on
194-
// surrounding code using a default distributed actor system,
195-
// or extensions providing it.
196-
return [
197-
"""
198-
\(proto.modifiers) distributed actor $\(proto.name.trimmed): \(proto.name.trimmed),
199-
Distributed._DistributedActorStub
200-
{
170+
var primaryAssociatedTypes: [PrimaryAssociatedTypeSyntax] = []
171+
if let primaryTypes = proto.primaryAssociatedTypeClause?.primaryAssociatedTypes {
172+
primaryAssociatedTypes.append(contentsOf: primaryTypes)
173+
}
174+
175+
// The $Stub is always generic over the actor system: $Stub<ActorSystem>
176+
var primaryTypeParams: [String] = primaryAssociatedTypes.map {
177+
$0.as(PrimaryAssociatedTypeSyntax.self)!.name.trimmed.text
178+
}
179+
180+
// Don't duplicate the ActorSystem type parameter if it already was declared
181+
// on the protocol as a primary associated type;
182+
// otherwise, add it as first primary associated type.
183+
let actorSystemTypeParam: [String] =
184+
if primaryTypeParams.contains("ActorSystem") {
185+
[]
186+
} else if isGenericOverActorSystem {
187+
["ActorSystem"]
188+
} else {
189+
[]
190+
}
191+
192+
// Prepend the actor system type parameter, as we want it to be the first one
193+
primaryTypeParams = actorSystemTypeParam + primaryTypeParams
194+
let typeParamsClause =
195+
primaryTypeParams.isEmpty ? "" : "<" + primaryTypeParams.joined(separator: ", ") + ">"
196+
197+
var whereClause: String = ""
198+
do {
199+
let associatedTypeDecls = proto.associatedTypeDecls
200+
var typeParamConstraints: [String] = []
201+
for typeParamName in primaryTypeParams {
202+
if let decl = associatedTypeDecls[typeParamName] {
203+
if let inheritanceClause = decl.inheritanceClause {
204+
typeParamConstraints.append("\(typeParamName)\(inheritanceClause)")
205+
}
201206
}
202-
"""
203-
]
207+
}
208+
209+
if isGenericOverActorSystem, let specificActorSystemRequirement {
210+
typeParamConstraints = ["ActorSystem: \(specificActorSystemRequirement)"] + typeParamConstraints
211+
}
212+
213+
if !typeParamConstraints.isEmpty {
214+
whereClause += "\n where " + typeParamConstraints.joined(separator: ",\n ")
215+
}
204216
}
217+
218+
let stubActorBody: String =
219+
if isGenericOverActorSystem {
220+
// there may be no `where` clause specifying an actor system,
221+
// but perhaps there is a typealias (or extension with a typealias),
222+
// specifying a concrete actor system so we let this synthesize
223+
// an empty `$Greeter` -- this may fail, or succeed depending on
224+
// surrounding code using a default distributed actor system,
225+
// or extensions providing it.
226+
""
227+
} else if let specificActorSystemRequirement {
228+
"\(typealiasActorSystem(access: accessModifiers, proto, specificActorSystemRequirement))"
229+
} else {
230+
""
231+
}
232+
233+
return [
234+
"""
235+
\(proto.modifiers) distributed actor $\(proto.name.trimmed)\(raw: typeParamsClause): \(proto.name.trimmed),
236+
Distributed._DistributedActorStub \(raw: whereClause)
237+
{
238+
\(raw: stubActorBody)
239+
}
240+
"""
241+
]
205242
}
206243

207244
private static func typealiasActorSystem(access: DeclModifierListSyntax,
@@ -253,6 +290,23 @@ extension DeclModifierSyntax {
253290
}
254291
}
255292

293+
extension ProtocolDeclSyntax {
294+
var associatedTypeDecls: [String: AssociatedTypeDeclSyntax] {
295+
let visitor = AssociatedTypeDeclVisitor(viewMode: .all)
296+
visitor.walk(self)
297+
return visitor.associatedTypeDecls
298+
}
299+
300+
final class AssociatedTypeDeclVisitor: SyntaxVisitor {
301+
var associatedTypeDecls: [String: AssociatedTypeDeclSyntax] = [:]
302+
303+
override func visit(_ node: AssociatedTypeDeclSyntax) -> SyntaxVisitorContinueKind {
304+
associatedTypeDecls[node.name.text] = node
305+
return .skipChildren
306+
}
307+
}
308+
}
309+
256310
// ===== -----------------------------------------------------------------------
257311
// MARK: @Distributed.Resolvable macro errors
258312

stdlib/public/runtime/AccessibleFunction.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
#include "swift/Demangling/Demangler.h"
2121
#include "swift/Runtime/AccessibleFunction.h"
2222
#include "swift/Runtime/Concurrent.h"
23+
#include "swift/Runtime/EnvironmentVariables.h"
2324
#include "swift/Runtime/Metadata.h"
25+
#include "swift/Threading/Once.h"
2426
#include "Tracing.h"
2527

2628
#include <cstdint>
@@ -98,6 +100,29 @@ static Lazy<AccessibleFunctionsState> Functions;
98100

99101
} // end anonymous namespace
100102

103+
LLVM_ATTRIBUTE_UNUSED
104+
static void _dumpAccessibleFunctionRecords(void *context) {
105+
auto &S = Functions.get();
106+
107+
fprintf(stderr, "==== Accessible Function Records ====\n");
108+
int count = 0;
109+
for (const auto &section : S.SectionsToScan.snapshot()) {
110+
for (auto &record : section) {
111+
auto recordName =
112+
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
113+
auto demangledRecordName =
114+
swift::Demangle::demangleSymbolAsString(recordName);
115+
fprintf(stderr, "Record name: %s\n", recordName.data());
116+
fprintf(stderr, " Demangled: %s\n", demangledRecordName.c_str());
117+
fprintf(stderr, " Function Ptr: %p\n", record.Function.get());
118+
fprintf(stderr, " Flags.IsDistributed: %d\n", record.Flags.isDistributed());
119+
++count;
120+
}
121+
}
122+
fprintf(stderr, "Record count: %d\n", count);
123+
fprintf(stderr, "==== End of Accessible Function Records ====\n");
124+
}
125+
101126
static void _registerAccessibleFunctions(AccessibleFunctionsState &C,
102127
AccessibleFunctionsSection section) {
103128
C.SectionsToScan.push_back(section);
@@ -119,27 +144,6 @@ void swift::addImageAccessibleFunctionsBlockCallback(
119144
addImageAccessibleFunctionsBlockCallbackUnsafe(baseAddress, functions, size);
120145
}
121146

122-
// TODO(distributed): expose dumping records via a flag
123-
LLVM_ATTRIBUTE_UNUSED
124-
static void _dumpAccessibleFunctionRecords() {
125-
auto &S = Functions.get();
126-
127-
fprintf(stderr, "==== Accessible Function Records ====\n");
128-
int count = 0;
129-
for (const auto &section : S.SectionsToScan.snapshot()) {
130-
for (auto &record : section) {
131-
auto recordName =
132-
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
133-
fprintf(stderr, "Record name: %s\n", recordName.data());
134-
fprintf(stderr, " Function Ptr: %p\n", record.Function.get());
135-
fprintf(stderr, " Flags.IsDistributed: %d\n", record.Flags.isDistributed());
136-
++count;
137-
}
138-
}
139-
fprintf(stderr, "Record count: %d\n", count);
140-
fprintf(stderr, "==== End of Accessible Function Records ====\n");
141-
}
142-
143147
static const AccessibleFunctionRecord *
144148
_searchForFunctionRecord(AccessibleFunctionsState &S, llvm::StringRef name) {
145149
auto traceState = runtime::trace::accessible_function_scan_begin(name);
@@ -148,8 +152,9 @@ _searchForFunctionRecord(AccessibleFunctionsState &S, llvm::StringRef name) {
148152
for (auto &record : section) {
149153
auto recordName =
150154
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
151-
if (recordName == name)
155+
if (recordName == name) {
152156
return traceState.end(&record);
157+
}
153158
}
154159
}
155160
return nullptr;
@@ -160,13 +165,19 @@ const AccessibleFunctionRecord *
160165
swift::runtime::swift_findAccessibleFunction(const char *targetNameStart,
161166
size_t targetNameLength) {
162167
auto &S = Functions.get();
163-
164168
llvm::StringRef name{targetNameStart, targetNameLength};
169+
170+
if (swift::runtime::environment::SWIFT_DUMP_ACCESSIBLE_FUNCTIONS()) {
171+
static swift::once_t dumpAccessibleFunctionsToken;
172+
swift::once(dumpAccessibleFunctionsToken, _dumpAccessibleFunctionRecords, nullptr);
173+
}
174+
165175
// Look for an existing entry.
166176
{
167177
auto snapshot = S.Cache.snapshot();
168-
if (auto E = snapshot.find(name))
178+
if (auto E = snapshot.find(name)) {
169179
return E->getRecord();
180+
}
170181
}
171182

172183
// If entry doesn't exist (either record doesn't exist, hasn't been loaded, or

stdlib/public/runtime/EnvironmentVariables.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,10 @@ VARIABLE(SWIFT_IS_CURRENT_EXECUTOR_LEGACY_MODE_OVERRIDE, string, "",
128128
" 'legacy' (Legacy behavior), "
129129
" 'swift6' (Swift 6.0+ behavior)")
130130

131+
VARIABLE(SWIFT_DUMP_ACCESSIBLE_FUNCTIONS, string, "",
132+
"Dump a listing of all 'AccessibleFunctionRecord's upon first access. "
133+
"These are used to obtain function pointers from accessible function "
134+
"record names, e.g. by the Distributed runtime to invoke distributed "
135+
"functions.")
136+
131137
#undef VARIABLE

0 commit comments

Comments
 (0)