Skip to content

Commit 4d262df

Browse files
authored
Merge pull request #77004 from eeckstein/fix-vtable-specialization
embedded: fix several issues with vtable specialization
2 parents 3aed095 + 5c8fe55 commit 4d262df

File tree

12 files changed

+209
-84
lines changed

12 files changed

+209
-84
lines changed

SwiftCompilerSources/Sources/Optimizer/ModulePasses/MandatoryPerformanceOptimizations.swift

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ let mandatoryPerformanceOptimizations = ModulePass(name: "mandatory-performance-
3333
// For embedded Swift, optimize all the functions (there cannot be any
3434
// generics, type metadata, etc.)
3535
if moduleContext.options.enableEmbeddedSwift {
36-
// We need to specialize all vtables which are referenced from non-generic contexts. Beside
37-
// `alloc_ref`s of generic classes in non-generic functions, we also need to specialize generic
38-
// superclasses of non-generic classes. E.g. `class Derived : Base<Int> {}`
39-
specializeVTablesOfSuperclasses(moduleContext)
40-
4136
worklist.addAllNonGenericFunctions(of: moduleContext)
4237
} else {
4338
worklist.addAllPerformanceAnnotatedFunctions(of: moduleContext)
@@ -102,15 +97,18 @@ private func optimize(function: Function, _ context: FunctionPassContext, _ modu
10297
// Embedded Swift specific transformations
10398
case let alloc as AllocRefInst:
10499
if context.options.enableEmbeddedSwift {
105-
specializeVTableAndAddEntriesToWorklist(for: alloc.type, in: function,
106-
errorLocation: alloc.location,
107-
moduleContext, &worklist)
100+
specializeVTable(forClassType: alloc.type, errorLocation: alloc.location, moduleContext) {
101+
worklist.pushIfNotVisited($0)
102+
}
108103
}
109104
case let metatype as MetatypeInst:
110105
if context.options.enableEmbeddedSwift {
111-
specializeVTableAndAddEntriesToWorklist(for: metatype.type, in: function,
112-
errorLocation: metatype.location,
113-
moduleContext, &worklist)
106+
let instanceType = metatype.type.loweredInstanceTypeOfMetatype(in: function)
107+
if instanceType.isClass {
108+
specializeVTable(forClassType: instanceType, errorLocation: metatype.location, moduleContext) {
109+
worklist.pushIfNotVisited($0)
110+
}
111+
}
114112
}
115113
case let classMethod as ClassMethodInst:
116114
if context.options.enableEmbeddedSwift {
@@ -166,29 +164,6 @@ private func optimize(function: Function, _ context: FunctionPassContext, _ modu
166164
}
167165
}
168166

169-
private func specializeVTableAndAddEntriesToWorklist(for type: Type, in function: Function,
170-
errorLocation: Location,
171-
_ moduleContext: ModulePassContext,
172-
_ worklist: inout FunctionWorklist) {
173-
let vTablesCountBefore = moduleContext.vTables.count
174-
175-
guard specializeVTable(forClassType: type, errorLocation: errorLocation, moduleContext) != nil else {
176-
return
177-
}
178-
179-
// More than one new vtable might have been created (superclasses), process them all
180-
let vTables = moduleContext.vTables
181-
for i in vTablesCountBefore ..< vTables.count {
182-
for entry in vTables[i].entries
183-
// A new vtable can still contain a generic function if the method couldn't be specialized for some reason
184-
// and an error has been printed. Exclude generic functions to not run into an assert later.
185-
where !entry.implementation.isGeneric
186-
{
187-
worklist.pushIfNotVisited(entry.implementation)
188-
}
189-
}
190-
}
191-
192167
private func inlineAndDevirtualize(apply: FullApplySite, alreadyInlinedFunctions: inout Set<PathFunctionTuple>,
193168
_ context: FunctionPassContext, _ simplifyCtxt: SimplifyContext) {
194169
// De-virtualization and inlining in/into a "serialized" function might create function references to functions

SwiftCompilerSources/Sources/Optimizer/PassManager/ModulePassContext.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ struct ModulePassContext : Context, CustomStringConvertible {
168168
return VTable(bridged: bridgedVTable)
169169
}
170170

171+
func replaceVTableEntries(of vTable: VTable, with entries: [VTable.Entry]) {
172+
let bridgedEntries = entries.map { $0.bridged }
173+
bridgedEntries.withBridgedArrayRef {
174+
vTable.bridged.replaceEntries($0)
175+
}
176+
notifyFunctionTablesChanged()
177+
}
178+
171179
func createEmptyFunction(
172180
name: String,
173181
parameters: [ParameterInfo],
@@ -201,6 +209,10 @@ struct ModulePassContext : Context, CustomStringConvertible {
201209
}
202210
}
203211
}
212+
213+
func notifyFunctionTablesChanged() {
214+
_bridged.asNotificationHandler().notifyChanges(.functionTablesChanged)
215+
}
204216
}
205217

206218
extension GlobalVariable {

SwiftCompilerSources/Sources/Optimizer/Utilities/GenericSpecialization.swift

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,65 +13,92 @@
1313
import AST
1414
import SIL
1515

16-
@discardableResult
1716
func specializeVTable(forClassType classType: Type,
1817
errorLocation: Location,
19-
_ context: ModulePassContext) -> VTable?
18+
_ context: ModulePassContext,
19+
notifyNewFunction: (Function) -> ())
2020
{
21-
guard let nominal = classType.nominal,
22-
let classDecl = nominal as? ClassDecl,
23-
classType.isGenericAtAnyLevel else
24-
{
25-
return nil
26-
}
21+
var specializer = VTableSpecializer(errorLocation: errorLocation, context)
22+
specializer.specializeVTable(forClassType: classType, notifyNewFunction)
23+
}
2724

28-
if context.lookupSpecializedVTable(for: classType) != nil {
29-
return nil
30-
}
25+
private struct VTableSpecializer {
26+
let errorLocation: Location
27+
let context: ModulePassContext
3128

32-
guard let origVTable = context.lookupVTable(for: classDecl) else {
33-
context.diagnosticEngine.diagnose(errorLocation.sourceLoc, .cannot_specialize_class, classType)
34-
return nil
35-
}
29+
// The type of the first class in the hierarchy which implements a method
30+
private var baseTypesOfMethods = Dictionary<Function, Type>()
3631

37-
let classContextSubs = classType.contextSubstitutionMap
32+
init(errorLocation: Location, _ context: ModulePassContext) {
33+
self.errorLocation = errorLocation
34+
self.context = context
35+
}
3836

39-
let newEntries = origVTable.entries.map { origEntry in
40-
if !origEntry.implementation.isGeneric {
41-
return origEntry
37+
mutating func specializeVTable(forClassType classType: Type, _ notifyNewFunction: (Function) -> ()) {
38+
// First handle super classes.
39+
// This is also required for non-generic classes - in case a superclass is generic, e.g.
40+
// `class Derived : Base<Int> {}` - for two reasons:
41+
// * A vtable of a derived class references the vtable of the super class. And of course the referenced
42+
// super-class vtable needs to be a specialized vtable.
43+
// * Even a non-generic derived class can contain generic methods of the base class in case a base-class
44+
// method is not overridden.
45+
//
46+
if let superClassTy = classType.superClassType {
47+
specializeVTable(forClassType: superClassTy, notifyNewFunction)
4248
}
43-
let methodSubs = classContextSubs.getMethodSubstitutions(for: origEntry.implementation)
4449

45-
guard !methodSubs.conformances.contains(where: {!$0.isValid}),
46-
let specializedMethod = context.specialize(function: origEntry.implementation, for: methodSubs) else
47-
{
48-
context.diagnosticEngine.diagnose(origEntry.methodDecl.location.sourceLoc, .non_final_generic_class_function)
49-
return origEntry
50+
let classDecl = classType.nominal! as! ClassDecl
51+
guard let origVTable = context.lookupVTable(for: classDecl) else {
52+
context.diagnosticEngine.diagnose(errorLocation.sourceLoc, .cannot_specialize_class, classType)
53+
return
5054
}
5155

52-
context.deserializeAllCallees(of: specializedMethod, mode: .allFunctions)
53-
specializedMethod.set(linkage: .public, context)
54-
specializedMethod.set(isSerialized: false, context)
56+
for entry in origVTable.entries {
57+
if baseTypesOfMethods[entry.implementation] == nil {
58+
baseTypesOfMethods[entry.implementation] = classType
59+
}
60+
}
5561

56-
return VTable.Entry(kind: origEntry.kind, isNonOverridden: origEntry.isNonOverridden,
57-
methodDecl: origEntry.methodDecl, implementation: specializedMethod)
62+
if classType.isGenericAtAnyLevel {
63+
if context.lookupSpecializedVTable(for: classType) != nil {
64+
// We already specialized the vtable
65+
return
66+
}
67+
let newEntries = specializeEntries(of: origVTable, notifyNewFunction)
68+
context.createSpecializedVTable(entries: newEntries, for: classType, isSerialized: false)
69+
} else {
70+
if !origVTable.entries.contains(where: { $0.implementation.isGeneric }) {
71+
// The vtable (of the non-generic class) doesn't contain any generic functions (from a generic base class).
72+
return
73+
}
74+
let newEntries = specializeEntries(of: origVTable, notifyNewFunction)
75+
context.replaceVTableEntries(of: origVTable, with: newEntries)
76+
}
5877
}
5978

60-
let specializedVTable = context.createSpecializedVTable(entries: newEntries, for: classType, isSerialized: false)
61-
if let superClassTy = classType.superClassType {
62-
specializeVTable(forClassType: superClassTy, errorLocation: classDecl.location, context)
63-
}
64-
return specializedVTable
65-
}
79+
private func specializeEntries(of vTable: VTable, _ notifyNewFunction: (Function) -> ()) -> [VTable.Entry] {
80+
return vTable.entries.compactMap { entry in
81+
if !entry.implementation.isGeneric {
82+
return entry
83+
}
84+
let baseType = baseTypesOfMethods[entry.implementation]!
85+
let classContextSubs = baseType.contextSubstitutionMap
86+
let methodSubs = classContextSubs.getMethodSubstitutions(for: entry.implementation)
87+
88+
guard !methodSubs.conformances.contains(where: {!$0.isValid}),
89+
let specializedMethod = context.specialize(function: entry.implementation, for: methodSubs) else
90+
{
91+
context.diagnosticEngine.diagnose(entry.methodDecl.location.sourceLoc, .non_final_generic_class_function)
92+
return nil
93+
}
94+
notifyNewFunction(specializedMethod)
95+
96+
context.deserializeAllCallees(of: specializedMethod, mode: .allFunctions)
97+
specializedMethod.set(linkage: .public, context)
98+
specializedMethod.set(isSerialized: false, context)
6699

67-
func specializeVTablesOfSuperclasses(_ moduleContext: ModulePassContext) {
68-
for vtable in moduleContext.vTables {
69-
if !vtable.isSpecialized,
70-
!vtable.class.isGenericAtAnyLevel,
71-
let superClassTy = vtable.class.superClassType,
72-
superClassTy.isGenericAtAnyLevel
73-
{
74-
specializeVTable(forClassType: superClassTy, errorLocation: vtable.class.location, moduleContext)
100+
return VTable.Entry(kind: entry.kind, isNonOverridden: entry.isNonOverridden,
101+
methodDecl: entry.methodDecl, implementation: specializedMethod)
75102
}
76103
}
77104
}

SwiftCompilerSources/Sources/SIL/VTable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import AST
1414
import SILBridging
1515

1616
public struct VTable : CustomStringConvertible, NoReflectionChildren {
17-
let bridged: BridgedVTable
17+
public let bridged: BridgedVTable
1818

1919
public init(bridged: BridgedVTable) { self.bridged = bridged }
2020

include/swift/SIL/SILBridging.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,17 +1065,18 @@ struct BridgedVTableEntryArray {
10651065
};
10661066

10671067
struct BridgedVTable {
1068-
const swift::SILVTable * _Nonnull vTable;
1068+
swift::SILVTable * _Nonnull vTable;
10691069

10701070
BridgedOwnedString getDebugDescription() const;
10711071
BRIDGED_INLINE SwiftInt getNumEntries() const;
10721072
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedVTableEntry getEntry(SwiftInt index) const;
10731073
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedDeclObj getClass() const;
10741074
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType getSpecializedClassType() const;
1075+
BRIDGED_INLINE void replaceEntries(BridgedArrayRef bridgedEntries) const;
10751076
};
10761077

10771078
struct OptionalBridgedVTable {
1078-
const swift::SILVTable * _Nullable table;
1079+
swift::SILVTable * _Nullable table;
10791080
};
10801081

10811082
struct BridgedWitnessTableEntry {
@@ -1328,7 +1329,8 @@ struct BridgedChangeNotificationHandler {
13281329
instructionsChanged,
13291330
callsChanged,
13301331
branchesChanged,
1331-
effectsChanged
1332+
effectsChanged,
1333+
functionTablesChanged
13321334
};
13331335

13341336
void notifyChanges(Kind changeKind) const;

include/swift/SIL/SILBridgingImpl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,6 +1715,14 @@ BridgedType BridgedVTable::getSpecializedClassType() const {
17151715
return {vTable->getClassType()};
17161716
}
17171717

1718+
void BridgedVTable::replaceEntries(BridgedArrayRef bridgedEntries) const {
1719+
llvm::SmallVector<swift::SILVTableEntry, 8> entries;
1720+
for (const BridgedVTableEntry &e : bridgedEntries.unbridged<BridgedVTableEntry>()) {
1721+
entries.push_back(e.unbridged());
1722+
}
1723+
vTable->replaceEntries(entries);
1724+
}
1725+
17181726
//===----------------------------------------------------------------------===//
17191727
// BridgedWitnessTable, BridgedDefaultWitnessTable
17201728
//===----------------------------------------------------------------------===//

include/swift/SIL/SILVTable.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ class SILVTable final : public SILAllocated<SILVTable>,
205205
NumEntries = std::distance(Entries.begin(), end);
206206
}
207207

208+
void replaceEntries(ArrayRef<Entry> newEntries);
209+
208210
/// Verify that the vtable is well-formed for the given class.
209211
void verify(const SILModule &M) const;
210212

include/swift/SILOptimizer/PassManager/PassManager.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class SwiftPassInvocation {
6969
SILAnalysis::InvalidationKind changeNotifications =
7070
SILAnalysis::InvalidationKind::Nothing;
7171

72+
bool functionTablesChanged = false;
73+
7274
/// All slabs, allocated by the pass.
7375
SILModule::SlabList allocatedSlabs;
7476

@@ -141,6 +143,8 @@ class SwiftPassInvocation {
141143
/// Called by the pass when changes are made to the SIL.
142144
void notifyChanges(SILAnalysis::InvalidationKind invalidationKind);
143145

146+
void notifyFunctionTablesChanged();
147+
144148
/// Called by the pass manager before the pass starts running.
145149
void startModulePassRun(SILModuleTransform *transform);
146150

@@ -513,6 +517,9 @@ notifyChanges(SILAnalysis::InvalidationKind invalidationKind) {
513517
changeNotifications = (SILAnalysis::InvalidationKind)
514518
(changeNotifications | invalidationKind);
515519
}
520+
inline void SwiftPassInvocation::notifyFunctionTablesChanged() {
521+
functionTablesChanged = true;
522+
}
516523

517524
} // end namespace swift
518525

lib/SIL/IR/SILVTable.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,22 @@ void SILVTable::updateVTableCache(const Entry &entry) {
7575
M.VTableEntryCache[{this, entry.getMethod()}] = entry;
7676
}
7777

78+
void SILVTable::replaceEntries(ArrayRef<Entry> newEntries) {
79+
auto entries = getMutableEntries();
80+
ASSERT(newEntries.size() <= entries.size());
81+
for (unsigned i = 0; i < entries.size(); ++i) {
82+
entries[i].getImplementation()->decrementRefCount();
83+
if (i < newEntries.size()) {
84+
entries[i] = newEntries[i];
85+
entries[i].getImplementation()->incrementRefCount();
86+
updateVTableCache(entries[i]);
87+
} else {
88+
removeFromVTableCache(entries[i]);
89+
}
90+
}
91+
NumEntries = newEntries.size();
92+
}
93+
7894
SILVTable::SILVTable(ClassDecl *c, SILType classType,
7995
SerializedKind_t serialized, ArrayRef<Entry> entries)
8096
: Class(c), classType(classType), SerializedKind(serialized),

lib/SILOptimizer/PassManager/PassManager.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,7 @@ void SwiftPassInvocation::finishedModulePassRun() {
15161516
endPass();
15171517
assert(!function && transform && "not running a pass");
15181518
assert(changeNotifications == SILAnalysis::InvalidationKind::Nothing
1519+
&& !functionTablesChanged
15191520
&& "unhandled change notifications at end of module pass");
15201521
transform = nullptr;
15211522
}
@@ -1551,6 +1552,7 @@ void SwiftPassInvocation::endPass() {
15511552
void SwiftPassInvocation::beginTransformFunction(SILFunction *function) {
15521553
assert(!this->function && transform && "not running a pass");
15531554
assert(changeNotifications == SILAnalysis::InvalidationKind::Nothing
1555+
&& !functionTablesChanged
15541556
&& "change notifications not cleared");
15551557
this->function = function;
15561558
}
@@ -1561,6 +1563,10 @@ void SwiftPassInvocation::endTransformFunction() {
15611563
passManager->invalidateAnalysis(function, changeNotifications);
15621564
changeNotifications = SILAnalysis::InvalidationKind::Nothing;
15631565
}
1566+
if (functionTablesChanged) {
1567+
passManager->invalidateFunctionTables();
1568+
functionTablesChanged = false;
1569+
}
15641570
function = nullptr;
15651571
assert(numBlockSetsAllocated == 0 && "Not all BasicBlockSets deallocated");
15661572
assert(numNodeSetsAllocated == 0 && "Not all NodeSets deallocated");
@@ -1580,6 +1586,7 @@ void SwiftPassInvocation::endVerifyFunction() {
15801586
assert(function);
15811587
if (!transform) {
15821588
assert(changeNotifications == SILAnalysis::InvalidationKind::Nothing &&
1589+
!functionTablesChanged &&
15831590
"verifyication must not change the SIL of a function");
15841591
assert(numBlockSetsAllocated == 0 && "Not all BasicBlockSets deallocated");
15851592
assert(numNodeSetsAllocated == 0 && "Not all NodeSets deallocated");
@@ -1634,6 +1641,9 @@ void BridgedChangeNotificationHandler::notifyChanges(Kind changeKind) const {
16341641
case Kind::effectsChanged:
16351642
invocation->notifyChanges(SILAnalysis::InvalidationKind::Effects);
16361643
break;
1644+
case Kind::functionTablesChanged:
1645+
invocation->notifyFunctionTablesChanged();
1646+
break;
16371647
}
16381648
}
16391649

lib/SILOptimizer/Utils/Generics.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,6 +2324,9 @@ bool swift::specializeClassMethodInst(ClassMethodInst *cm) {
23242324

23252325
SILValue instance = cm->getOperand();
23262326
SILType classTy = instance->getType();
2327+
if (classTy.is<MetatypeType>())
2328+
classTy = classTy.getLoweredInstanceTypeOfMetatype(cm->getFunction());
2329+
23272330
CanType astType = classTy.getASTType();
23282331
if (!astType->isSpecialized())
23292332
return false;

0 commit comments

Comments
 (0)