Skip to content

[Distributed] Protocol mangling and primary associated type fixes #77584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ PROTOCOL(Differentiable)

// Distributed Actors
PROTOCOL(DistributedActor)
PROTOCOL_(DistributedActorStub)
PROTOCOL(DistributedActorSystem)
PROTOCOL(DistributedTargetInvocationEncoder)
PROTOCOL(DistributedTargetInvocationDecoder)
Expand Down
31 changes: 26 additions & 5 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4453,12 +4453,33 @@ void ASTMangler::appendDistributedThunk(
return nullptr;
};

if (auto *P = referenceInProtocolContextOrRequirement()) {
appendContext(P->getDeclContext(), base,
thunk->getAlternateModuleName());
appendIdentifier(Twine("$", P->getNameStr()).str());
appendOperator("C"); // necessary for roundtrip, though we don't use it
// Determine if we should mangle with a $Target substitute decl context,
// this matters for @Resolvable calls / protocol calls where the caller
// does not know the type of the recipient distributed actor, and we use the
// $Target type as substitute to then generically invoke it on the type of the
// recipient, whichever 'protocol Type'-conforming type it will be.
NominalTypeDecl *stubActorDecl = nullptr;
if (auto P = referenceInProtocolContextOrRequirement()) {
auto &C = thunk->getASTContext();
auto M = thunk->getModuleContext();

SmallVector<ValueDecl *, 1> stubClassLookupResults;
C.lookupInModule(M, llvm::Twine("$", P->getNameStr()).str(), stubClassLookupResults);

assert(stubClassLookupResults.size() <= 1 && "Found multiple distributed stub types!");
if (stubClassLookupResults.size() > 0) {
stubActorDecl =
dyn_cast_or_null<NominalTypeDecl>(stubClassLookupResults.front());
}
}

if (stubActorDecl) {
// Effectively mangle the thunk as if it was declared on the $StubTarget
// type, rather than on a `protocol Target`.
appendContext(stubActorDecl, base, thunk->getAlternateModuleName());
} else {
// There's no need to replace the context, this is a normal concrete type
// remote call identifier.
appendContextOf(thunk, base);
}

Expand Down
4 changes: 2 additions & 2 deletions lib/IRGen/GenDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,8 @@ void IRGenModule::emitDistributedTargetAccessor(ThunkOrRequirement target) {
IRGenMangler mangler;

addAccessibleFunction(AccessibleFunction::forDistributed(
mangler.mangleDistributedThunkRecord(targetDecl),
mangler.mangleDistributedThunk(targetDecl),
/*recordName=*/mangler.mangleDistributedThunkRecord(targetDecl),
/*accessorName=*/mangler.mangleDistributedThunk(targetDecl),
accessor.getTargetType(),
getAddrOfAsyncFunctionPointer(accessorRef)));
}
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6956,6 +6956,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::Identifiable:
case KnownProtocolKind::Actor:
case KnownProtocolKind::DistributedActor:
case KnownProtocolKind::DistributedActorStub:
case KnownProtocolKind::DistributedActorSystem:
case KnownProtocolKind::DistributedTargetInvocationEncoder:
case KnownProtocolKind::DistributedTargetInvocationDecoder:
Expand Down
114 changes: 84 additions & 30 deletions lib/Macros/Sources/SwiftMacros/DistributedResolvableMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ extension DistributedResolvableMacro {
return []
}

var isGenericStub = false
var isGenericOverActorSystem = false
var specificActorSystemRequirement: TypeSyntax?

let accessModifiers = proto.accessControlModifiers
Expand All @@ -140,15 +140,15 @@ extension DistributedResolvableMacro {
case .conformanceRequirement(let conformanceReq)
where conformanceReq.leftType.isActorSystem:
specificActorSystemRequirement = conformanceReq.rightType.trimmed
isGenericStub = true
isGenericOverActorSystem = true

case .sameTypeRequirement(let sameTypeReq):
switch sameTypeReq.leftType {
case .type(let type) where type.isActorSystem:
switch sameTypeReq.rightType.trimmed {
case .type(let rightType):
specificActorSystemRequirement = rightType
isGenericStub = false
isGenericOverActorSystem = false

case .expr:
throw DiagnosticsError(
Expand All @@ -167,41 +167,78 @@ extension DistributedResolvableMacro {
}
}

if isGenericStub, let specificActorSystemRequirement {
return [
"""
\(proto.modifiers) distributed actor $\(proto.name.trimmed)<ActorSystem>: \(proto.name.trimmed),
Distributed._DistributedActorStub
where ActorSystem: \(specificActorSystemRequirement)
{ }
"""
]
} else if let specificActorSystemRequirement {
return [
"""
\(proto.modifiers) distributed actor $\(proto.name.trimmed): \(proto.name.trimmed),
Distributed._DistributedActorStub
{
\(typealiasActorSystem(access: accessModifiers, proto, specificActorSystemRequirement))
}
"""
]
var primaryAssociatedTypes: [PrimaryAssociatedTypeSyntax] = []
if let primaryTypes = proto.primaryAssociatedTypeClause?.primaryAssociatedTypes {
primaryAssociatedTypes.append(contentsOf: primaryTypes)
}

// The $Stub is always generic over the actor system: $Stub<ActorSystem>
var primaryTypeParams: [String] = primaryAssociatedTypes.map {
$0.as(PrimaryAssociatedTypeSyntax.self)!.name.trimmed.text
}

// Don't duplicate the ActorSystem type parameter if it already was declared
// on the protocol as a primary associated type;
// otherwise, add it as first primary associated type.
let actorSystemTypeParam: [String]
if primaryTypeParams.contains("ActorSystem") {
actorSystemTypeParam = []
} else if isGenericOverActorSystem {
actorSystemTypeParam = ["ActorSystem"]
} else {
actorSystemTypeParam = []
}

// Prepend the actor system type parameter, as we want it to be the first one
primaryTypeParams = actorSystemTypeParam + primaryTypeParams
let typeParamsClause =
primaryTypeParams.isEmpty ? "" : "<" + primaryTypeParams.joined(separator: ", ") + ">"

var whereClause: String = ""
do {
let associatedTypeDecls = proto.associatedTypeDecls
var typeParamConstraints: [String] = []
for typeParamName in primaryTypeParams {
if let decl = associatedTypeDecls[typeParamName] {
if let inheritanceClause = decl.inheritanceClause {
typeParamConstraints.append("\(typeParamName)\(inheritanceClause)")
}
}
}

if isGenericOverActorSystem, let specificActorSystemRequirement {
typeParamConstraints = ["ActorSystem: \(specificActorSystemRequirement)"] + typeParamConstraints
}

if !typeParamConstraints.isEmpty {
whereClause += "\n where " + typeParamConstraints.joined(separator: ",\n ")
}
}

let stubActorBody: String
if isGenericOverActorSystem {
// there may be no `where` clause specifying an actor system,
// but perhaps there is a typealias (or extension with a typealias),
// specifying a concrete actor system so we let this synthesize
// an empty `$Greeter` -- this may fail, or succeed depending on
// surrounding code using a default distributed actor system,
// or extensions providing it.
return [
"""
\(proto.modifiers) distributed actor $\(proto.name.trimmed): \(proto.name.trimmed),
Distributed._DistributedActorStub
{
}
"""
]
stubActorBody = ""
} else if let specificActorSystemRequirement {
stubActorBody = "\(typealiasActorSystem(access: accessModifiers, proto, specificActorSystemRequirement))"
} else {
stubActorBody = ""
}

return [
"""
\(proto.modifiers) distributed actor $\(proto.name.trimmed)\(raw: typeParamsClause): \(proto.name.trimmed),
Distributed._DistributedActorStub \(raw: whereClause)
{
\(raw: stubActorBody)
}
"""
]
}

private static func typealiasActorSystem(access: DeclModifierListSyntax,
Expand Down Expand Up @@ -253,6 +290,23 @@ extension DeclModifierSyntax {
}
}

extension ProtocolDeclSyntax {
var associatedTypeDecls: [String: AssociatedTypeDeclSyntax] {
let visitor = AssociatedTypeDeclVisitor(viewMode: .all)
visitor.walk(self)
return visitor.associatedTypeDecls
}

final class AssociatedTypeDeclVisitor: SyntaxVisitor {
var associatedTypeDecls: [String: AssociatedTypeDeclSyntax] = [:]

override func visit(_ node: AssociatedTypeDeclSyntax) -> SyntaxVisitorContinueKind {
associatedTypeDecls[node.name.text] = node
return .skipChildren
}
}
}

// ===== -----------------------------------------------------------------------
// MARK: @Distributed.Resolvable macro errors

Expand Down
59 changes: 35 additions & 24 deletions stdlib/public/runtime/AccessibleFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "swift/Demangling/Demangler.h"
#include "swift/Runtime/AccessibleFunction.h"
#include "swift/Runtime/Concurrent.h"
#include "swift/Runtime/EnvironmentVariables.h"
#include "swift/Runtime/Metadata.h"
#include "swift/Threading/Once.h"
#include "Tracing.h"

#include <cstdint>
Expand Down Expand Up @@ -98,6 +100,29 @@ static Lazy<AccessibleFunctionsState> Functions;

} // end anonymous namespace

LLVM_ATTRIBUTE_UNUSED
static void _dumpAccessibleFunctionRecords(void *context) {
auto &S = Functions.get();

fprintf(stderr, "==== Accessible Function Records ====\n");
int count = 0;
for (const auto &section : S.SectionsToScan.snapshot()) {
for (auto &record : section) {
auto recordName =
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
auto demangledRecordName =
swift::Demangle::demangleSymbolAsString(recordName);
fprintf(stderr, "Record name: %s\n", recordName.data());
fprintf(stderr, " Demangled: %s\n", demangledRecordName.c_str());
fprintf(stderr, " Function Ptr: %p\n", record.Function.get());
fprintf(stderr, " Flags.IsDistributed: %d\n", record.Flags.isDistributed());
++count;
}
}
fprintf(stderr, "Record count: %d\n", count);
fprintf(stderr, "==== End of Accessible Function Records ====\n");
}

static void _registerAccessibleFunctions(AccessibleFunctionsState &C,
AccessibleFunctionsSection section) {
C.SectionsToScan.push_back(section);
Expand All @@ -119,27 +144,6 @@ void swift::addImageAccessibleFunctionsBlockCallback(
addImageAccessibleFunctionsBlockCallbackUnsafe(baseAddress, functions, size);
}

// TODO(distributed): expose dumping records via a flag
LLVM_ATTRIBUTE_UNUSED
static void _dumpAccessibleFunctionRecords() {
auto &S = Functions.get();

fprintf(stderr, "==== Accessible Function Records ====\n");
int count = 0;
for (const auto &section : S.SectionsToScan.snapshot()) {
for (auto &record : section) {
auto recordName =
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
fprintf(stderr, "Record name: %s\n", recordName.data());
fprintf(stderr, " Function Ptr: %p\n", record.Function.get());
fprintf(stderr, " Flags.IsDistributed: %d\n", record.Flags.isDistributed());
++count;
}
}
fprintf(stderr, "Record count: %d\n", count);
fprintf(stderr, "==== End of Accessible Function Records ====\n");
}

static const AccessibleFunctionRecord *
_searchForFunctionRecord(AccessibleFunctionsState &S, llvm::StringRef name) {
auto traceState = runtime::trace::accessible_function_scan_begin(name);
Expand All @@ -148,8 +152,9 @@ _searchForFunctionRecord(AccessibleFunctionsState &S, llvm::StringRef name) {
for (auto &record : section) {
auto recordName =
swift::Demangle::makeSymbolicMangledNameStringRef(record.Name.get());
if (recordName == name)
if (recordName == name) {
return traceState.end(&record);
}
}
}
return nullptr;
Expand All @@ -160,13 +165,19 @@ const AccessibleFunctionRecord *
swift::runtime::swift_findAccessibleFunction(const char *targetNameStart,
size_t targetNameLength) {
auto &S = Functions.get();

llvm::StringRef name{targetNameStart, targetNameLength};

if (swift::runtime::environment::SWIFT_DUMP_ACCESSIBLE_FUNCTIONS()) {
static swift::once_t dumpAccessibleFunctionsToken;
swift::once(dumpAccessibleFunctionsToken, _dumpAccessibleFunctionRecords, nullptr);
}

// Look for an existing entry.
{
auto snapshot = S.Cache.snapshot();
if (auto E = snapshot.find(name))
if (auto E = snapshot.find(name)) {
return E->getRecord();
}
}

// If entry doesn't exist (either record doesn't exist, hasn't been loaded, or
Expand Down
6 changes: 6 additions & 0 deletions stdlib/public/runtime/EnvironmentVariables.def
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,10 @@ VARIABLE(SWIFT_IS_CURRENT_EXECUTOR_LEGACY_MODE_OVERRIDE, string, "",
" 'legacy' (Legacy behavior), "
" 'swift6' (Swift 6.0+ behavior)")

VARIABLE(SWIFT_DUMP_ACCESSIBLE_FUNCTIONS, string, "",
"Dump a listing of all 'AccessibleFunctionRecord's upon first access. "
"These are used to obtain function pointers from accessible function "
"record names, e.g. by the Distributed runtime to invoke distributed "
"functions.")

#undef VARIABLE
Loading