Skip to content

Commit 7e671dd

Browse files
committed
Implement an SPI to iterate types who conform to some protocol
1 parent 81ed148 commit 7e671dd

File tree

6 files changed

+284
-1
lines changed

6 files changed

+284
-1
lines changed

stdlib/public/Reflection/Sources/Reflection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ list(APPEND SWIFT_REFLECTION_SWIFT_FLAGS
1717

1818
add_swift_target_library(swiftReflection ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
1919
Case.swift
20+
Conformances.swift
2021
Field.swift
2122
GenericArguments.swift
2223
KeyPath.swift
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
13+
14+
import Swift
15+
import _Runtime
16+
17+
@available(SwiftStdlib 5.9, *)
18+
@_silgen_name("_swift_reflection_withConformanceCache")
19+
func _withConformanceCache(
20+
_ proto: ProtocolDescriptor,
21+
_ context: UnsafeMutableRawPointer,
22+
_ callback: @convention(c) (
23+
/* Array of types */ UnsafePointer<UnsafeRawPointer>,
24+
/* Number of types */ Int,
25+
/* Context we just passed */ UnsafeMutableRawPointer
26+
) -> ()
27+
)
28+
29+
@available(SwiftStdlib 5.9, *)
30+
@_spi(Reflection)
31+
public func _typesThatConform(to type: Any.Type) -> [Any.Type]? {
32+
let meta = Metadata(type)
33+
34+
guard meta.kind == .existential else {
35+
return nil
36+
}
37+
38+
let existential = meta.existential
39+
40+
let protos = existential.protocols
41+
42+
guard protos.count == 1 else {
43+
return nil
44+
}
45+
46+
let proto = protos[0]
47+
48+
var result: [Any.Type] = []
49+
50+
withUnsafeMutablePointer(to: &result) {
51+
_withConformanceCache(proto, UnsafeMutableRawPointer($0)) {
52+
let buffer = UnsafeBufferPointer<Any.Type>(
53+
start: UnsafePointer<Any.Type>($0._rawValue),
54+
count: $1
55+
)
56+
57+
let arrayPtr = $2.assumingMemoryBound(to: [Any.Type].self)
58+
59+
arrayPtr.pointee = Array(buffer)
60+
}
61+
}
62+
63+
return result
64+
}
65+
66+
#endif

stdlib/public/Reflection/Sources/_Runtime/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ add_swift_target_library(swift_Runtime ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_ST
6464
Utils/RelativePointer.swift
6565
Utils/TypeCache.swift
6666

67+
Caches.cpp
6768
ConformanceDescriptor.swift
6869
ExistentialContainer.swift
6970
Functions.swift
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
// Only as SPI for Darwin platforms for right now...
13+
#if defined(__MACH__)
14+
15+
#include "swift/ABI/Metadata.h"
16+
#include "swift/Basic/Lazy.h"
17+
#include "swift/Runtime/Concurrent.h"
18+
#include "swift/Runtime/Config.h"
19+
#include "swift/Runtime/ImageInspection.h"
20+
#include "swift/Runtime/ImageInspectionMachO.h"
21+
#include <cstdint>
22+
#include <unordered_map>
23+
24+
using namespace swift;
25+
26+
//===----------------------------------------------------------------------===//
27+
// Protocol Conformance Cache
28+
//===----------------------------------------------------------------------===//
29+
30+
namespace {
31+
struct ConformanceSection {
32+
const ProtocolConformanceRecord *Begin, *End;
33+
34+
ConformanceSection(const ProtocolConformanceRecord *begin,
35+
const ProtocolConformanceRecord *end)
36+
: Begin(begin), End(end) {}
37+
38+
ConformanceSection(const void *ptr, uintptr_t size) {
39+
auto bytes = reinterpret_cast<const char *>(ptr);
40+
Begin = reinterpret_cast<const ProtocolConformanceRecord *>(ptr);
41+
End = reinterpret_cast<const ProtocolConformanceRecord *>(bytes + size);
42+
}
43+
44+
const ProtocolConformanceRecord *begin() const {
45+
return Begin;
46+
}
47+
const ProtocolConformanceRecord *end() const {
48+
return End;
49+
}
50+
};
51+
}
52+
53+
struct ConformanceCache {
54+
// All accesses to Cache and LastSectionCount must be within CacheMutex's
55+
// lock scope.
56+
57+
std::unordered_map<
58+
/* Key */ const ProtocolDescriptor *,
59+
/* Value */ std::vector<const Metadata *>
60+
> Cache;
61+
Mutex CacheMutex;
62+
ConcurrentReadableArray<ConformanceSection> Sections;
63+
64+
size_t LastSectionCount = 0;
65+
66+
ConformanceCache() {
67+
initializeProtocolConformanceLookup();
68+
}
69+
};
70+
71+
static Lazy<ConformanceCache> Conformances;
72+
73+
void swift::addImageProtocolConformanceBlockCallbackUnsafe(
74+
const void *baseAddress,
75+
const void *conformances,
76+
uintptr_t conformancesSize) {
77+
assert(conformancesSize % sizeof(ProtocolConformanceRecord) == 0 &&
78+
"conformances section not a multiple of ProtocolConformanceRecord");
79+
80+
// Conformance cache should always be sufficiently initialized by this point.
81+
auto &C = Conformances.unsafeGetAlreadyInitialized();
82+
83+
C.Sections.push_back(ConformanceSection{conformances, conformancesSize});
84+
}
85+
86+
// WARNING: the callbacks are called from unsafe contexts (with the dyld and
87+
// ObjC runtime locks held) and must be very careful in what they do. Locking
88+
// must be arranged to avoid deadlocks (other code must never call out to dyld
89+
// or ObjC holding a lock that gets taken in one of these callbacks) and the
90+
// new/delete operators must not be called, in case a program supplies an
91+
// overload which does not cooperate with these requirements.
92+
93+
void swift::initializeProtocolConformanceLookup() {
94+
REGISTER_FUNC(
95+
addImageCallback<TextSegment, ProtocolConformancesSection,
96+
addImageProtocolConformanceBlockCallbackUnsafe>);
97+
}
98+
99+
static const Metadata *_getCanonicalTypeMetadata(
100+
const ProtocolConformanceDescriptor *conformance) {
101+
switch (conformance->getTypeKind()) {
102+
case TypeReferenceKind::DirectTypeDescriptor:
103+
case TypeReferenceKind::IndirectTypeDescriptor: {
104+
if (auto anyType = conformance->getTypeDescriptor()) {
105+
if (auto type = dyn_cast<TypeContextDescriptor>(anyType)) {
106+
if (!type->isGeneric()) {
107+
if (auto accessFn = type->getAccessFunction()) {
108+
return accessFn(MetadataState::Abstract).Value;
109+
}
110+
}
111+
}
112+
}
113+
114+
return nullptr;
115+
}
116+
117+
case TypeReferenceKind::DirectObjCClassName:
118+
case TypeReferenceKind::IndirectObjCClass:
119+
return nullptr;
120+
}
121+
}
122+
123+
using ConformanceCacheCallback = void (*)(const Metadata **,
124+
size_t, void *);
125+
126+
SWIFT_RUNTIME_STDLIB_SPI SWIFT_CC(swift)
127+
void _swift_reflection_withConformanceCache(const ProtocolDescriptor *proto,
128+
void *context,
129+
ConformanceCacheCallback callback) {
130+
auto &C = Conformances.get();
131+
132+
auto snapshot = C.Sections.snapshot();
133+
134+
Mutex::ScopedLock lock(C.CacheMutex);
135+
136+
if (C.LastSectionCount > 0 && snapshot.count() <= C.LastSectionCount) {
137+
auto entry = C.Cache.find(proto);
138+
139+
if (entry != C.Cache.end()) {
140+
callback(entry->second.data(), entry->second.size(), context);
141+
return;
142+
}
143+
}
144+
145+
std::vector<const Metadata *> types = {};
146+
147+
for (auto &section : snapshot) {
148+
for (auto &record : section) {
149+
auto conformance = record.get();
150+
151+
if (conformance->getProtocol() != proto) {
152+
continue;
153+
}
154+
155+
if (conformance->hasConditionalRequirements()) {
156+
continue;
157+
}
158+
159+
if (auto type = _getCanonicalTypeMetadata(conformance)) {
160+
types.push_back(type);
161+
}
162+
}
163+
}
164+
165+
callback(types.data(), types.size(), context);
166+
167+
C.Cache[proto] = types;
168+
C.LastSectionCount = snapshot.count();
169+
}
170+
171+
#endif

stdlib/public/Reflection/Sources/_Runtime/Metadata/Metadata.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ extension Metadata {
5353
public var `enum`: EnumMetadata {
5454
EnumMetadata(ptr)
5555
}
56-
56+
57+
@inlinable
58+
public var existential: ExistentialMetadata {
59+
ExistentialMetadata(ptr)
60+
}
61+
5762
@inlinable
5863
public var extendedExistential: ExtendedExistentialMetadata {
5964
ExtendedExistentialMetadata(ptr)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
// REQUIRES: reflection_runtime
4+
// REQUIRES: OS=macosx || OS=ios || OS=tvos || OS=watchos
5+
// UNSUPPORTED: freestanding
6+
7+
import StdlibUnittest
8+
9+
@_spi(Reflection)
10+
import Reflection
11+
12+
let suite = TestSuite("Conformances")
13+
14+
protocol MyCustomProtocol {}
15+
16+
struct Conformer: MyCustomProtocol {}
17+
struct NotConformer {}
18+
struct GenericConformer<T>: MyCustomProtocol {}
19+
struct GenericConditionalConformer<T> {}
20+
21+
extension GenericConditionalConformer: MyCustomProtocol where T: MyCustomProtocol {}
22+
23+
if #available(SwiftStdlib 5.9, *) {
24+
suite.test("Basic") {
25+
let customTypes = _typesThatConform(to: (any MyCustomProtocol).self)
26+
27+
expectNotNil(customTypes)
28+
expectEqual(customTypes!.count, 1)
29+
expectTrue(customTypes![0] == Conformer.self)
30+
31+
let kpType = _typesThatConform(to: (any _AppendKeyPath).self)
32+
33+
expectNotNil(kpType)
34+
expectEqual(kpType!.count, 1)
35+
expectTrue(kpType![0] == AnyKeyPath.self)
36+
}
37+
}
38+
39+
runAllTests()

0 commit comments

Comments
 (0)