Skip to content

embedded: support class existentials with generic classes #76669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,21 @@ struct FunctionUses {

for vTable in context.vTables {
for entry in vTable.entries {
markUnknown(entry.function)
markUnknown(entry.implementation)
}
}

for witnessTable in context.witnessTables {
for entry in witnessTable.entries {
if entry.kind == .Method, let f = entry.methodFunction {
if entry.kind == .method, let f = entry.methodFunction {
markUnknown(f)
}
}
}

for witnessTable in context.defaultWitnessTables {
for entry in witnessTable.entries {
if entry.kind == .Method, let f = entry.methodFunction {
if entry.kind == .method, let f = entry.methodFunction {
markUnknown(f)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ private func insertCompensatingInstructions(for inst: Instruction, in failureBlo
let newInst: SingleValueInstruction
switch inst {
case let ier as InitExistentialRefInst:
newInst = builder.createInitExistentialRef(instance: newArg, existentialType: ier.type, useConformancesOf: ier)
newInst = builder.createInitExistentialRef(instance: newArg,
existentialType: ier.type,
formalConcreteType: ier.formalConcreteType,
conformances: ier.conformances)
case let uc as UpcastInst:
newInst = builder.createUpcast(from: newArg, to: uc.type)
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,23 @@ let mandatoryPerformanceOptimizations = ModulePass(name: "mandatory-performance-
// For embedded Swift, optimize all the functions (there cannot be any
// generics, type metadata, etc.)
if moduleContext.options.enableEmbeddedSwift {
// We need to specialize all vtables which are referenced from non-generic contexts. Beside
// `alloc_ref`s of generic classes in non-generic functions, we also need to specialize generic
// superclasses of non-generic classes. E.g. `class Derived : Base<Int> {}`
specializeVTablesOfSuperclasses(moduleContext)

worklist.addAllNonGenericFunctions(of: moduleContext)
} else {
worklist.addAllPerformanceAnnotatedFunctions(of: moduleContext)
worklist.addAllAnnotatedGlobalInitOnceFunctions(of: moduleContext)
}

optimizeFunctionsTopDown(using: &worklist, moduleContext)

if moduleContext.options.enableEmbeddedSwift {
// Print errors for generic functions in vtables, which is not allowed in embedded Swift.
checkVTablesForGenericFunctions(moduleContext)
}
}

private func optimizeFunctionsTopDown(using worklist: inout FunctionWorklist,
Expand Down Expand Up @@ -92,17 +102,26 @@ private func optimize(function: Function, _ context: FunctionPassContext, _ modu
// Embedded Swift specific transformations
case let alloc as AllocRefInst:
if context.options.enableEmbeddedSwift {
specializeVTableAndAddEntriesToWorklist(for: alloc.type, in: function, context, moduleContext, &worklist)
specializeVTableAndAddEntriesToWorklist(for: alloc.type, in: function,
errorLocation: alloc.location,
moduleContext, &worklist)
}
case let metatype as MetatypeInst:
if context.options.enableEmbeddedSwift {
specializeVTableAndAddEntriesToWorklist(for: metatype.type, in: function, context, moduleContext, &worklist)
specializeVTableAndAddEntriesToWorklist(for: metatype.type, in: function,
errorLocation: metatype.location,
moduleContext, &worklist)
}
case let classMethod as ClassMethodInst:
if context.options.enableEmbeddedSwift {
_ = context.specializeClassMethodInst(classMethod)
}

case let initExRef as InitExistentialRefInst:
if context.options.enableEmbeddedSwift {
specializeWitnessTables(for: initExRef, moduleContext, &worklist)
}

// We need to de-virtualize deinits of non-copyable types to be able to specialize the deinitializers.
case let destroyValue as DestroyValueInst:
if !devirtualizeDeinits(of: destroyValue, simplifyCtxt) {
Expand Down Expand Up @@ -144,19 +163,24 @@ private func optimize(function: Function, _ context: FunctionPassContext, _ modu
}

private func specializeVTableAndAddEntriesToWorklist(for type: Type, in function: Function,
_ context: FunctionPassContext, _ moduleContext: ModulePassContext,
errorLocation: Location,
_ moduleContext: ModulePassContext,
_ worklist: inout FunctionWorklist) {
let vTablesCountBefore = moduleContext.vTables.count

guard context.specializeVTable(for: type, in: function) != nil else {
guard specializeVTable(forClassType: type, errorLocation: errorLocation, moduleContext) != nil else {
return
}

// More than one new vtable might have been created (superclasses), process them all
let vTables = moduleContext.vTables
for i in vTablesCountBefore ..< vTables.count {
for entry in vTables[i].entries {
worklist.pushIfNotVisited(entry.function)
for entry in vTables[i].entries
// A new vtable can still contain a generic function if the method couldn't be specialized for some reason
// and an error has been printed. Exclude generic functions to not run into an assert later.
where !entry.implementation.isGeneric
{
worklist.pushIfNotVisited(entry.implementation)
}
}
}
Expand Down Expand Up @@ -240,6 +264,45 @@ private func shouldInline(apply: FullApplySite, callee: Function, alreadyInlined
return false
}

private func specializeWitnessTables(for initExRef: InitExistentialRefInst, _ context: ModulePassContext,
_ worklist: inout FunctionWorklist)
{
for conformance in initExRef.conformances where conformance.isConcrete {
let origWitnessTable = context.lookupWitnessTable(for: conformance)
if conformance.isSpecialized {
if origWitnessTable == nil {
let wt = specializeWitnessTable(forConformance: conformance, errorLocation: initExRef.location, context)
worklist.addWitnessMethods(of: wt)
}
} else if let origWitnessTable {
checkForGenericMethods(in: origWitnessTable, errorLocation: initExRef.location, context)
}
}
}

private func checkForGenericMethods(in witnessTable: WitnessTable,
errorLocation: Location,
_ context: ModulePassContext)
{
for entry in witnessTable.entries where entry.kind == .method {
if let method = entry.methodFunction,
method.isGeneric
{
context.diagnosticEngine.diagnose(errorLocation.sourceLoc, .cannot_specialize_witness_method,
entry.methodRequirement)
return
}
}
}

private func checkVTablesForGenericFunctions(_ context: ModulePassContext) {
for vTable in context.vTables where !vTable.class.isGenericAtAnyLevel {
for entry in vTable.entries where entry.implementation.isGeneric {
context.diagnosticEngine.diagnose(entry.methodDecl.location.sourceLoc, .non_final_generic_class_function)
}
}
}

private extension FullApplySite {
func resultIsUsedInGlobalInitialization() -> SmallProjectionPath? {
guard parentFunction.isGlobalInitOnceFunction,
Expand Down Expand Up @@ -445,6 +508,18 @@ fileprivate struct FunctionWorklist {
}
}

mutating func addWitnessMethods(of witnessTable: WitnessTable) {
for entry in witnessTable.entries where entry.kind == .method {
if let method = entry.methodFunction,
// A new witness table can still contain a generic function if the method couldn't be specialized for
// some reason and an error has been printed. Exclude generic functions to not run into an assert later.
!method.isGeneric
{
pushIfNotVisited(method)
}
}
}

mutating func pushIfNotVisited(_ element: Function) {
if pushedFunctions.insert(element).inserted {
functions.append(element)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ extension Context {
}
}

func lookupWitnessTable(for conformance: ProtocolConformance) -> WitnessTable? {
return _bridged.lookupWitnessTable(conformance.bridged).witnessTable
}

func lookupVTable(for classDecl: NominalTypeDecl) -> VTable? {
return _bridged.lookupVTable(classDecl.bridged).vTable
}

func lookupSpecializedVTable(for classType: Type) -> VTable? {
return _bridged.lookupSpecializedVTable(classType.bridged).vTable
}

func notifyNewFunction(function: Function, derivedFrom: Function) {
_bridged.addFunctionToPassManagerWorklist(function.bridged, derivedFrom.bridged)
}
Expand Down Expand Up @@ -221,7 +233,7 @@ extension MutatingContext {
}

func getContextSubstitutionMap(for type: Type) -> SubstitutionMap {
SubstitutionMap(_bridged.getContextSubstitutionMap(type.bridged))
SubstitutionMap(bridged: _bridged.getContextSubstitutionMap(type.bridged))
}

func notifyInstructionsChanged() {
Expand Down Expand Up @@ -327,13 +339,6 @@ struct FunctionPassContext : MutatingContext {
return false
}

func specializeVTable(for type: Type, in function: Function) -> VTable? {
guard let vtablePtr = _bridged.specializeVTableForType(type.bridged, function.bridged) else {
return nil
}
return VTable(bridged: BridgedVTable(vTable: vtablePtr))
}

func specializeClassMethodInst(_ cm: ClassMethodInst) -> Bool {
if _bridged.specializeClassMethodInst(cm.bridged) {
notifyInstructionsChanged()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ struct ModulePassContext : Context, CustomStringConvertible {
}

struct VTableArray : BridgedRandomAccessCollection {
fileprivate let bridged: BridgedPassContext.VTableArray
fileprivate let bridgedCtxt: BridgedPassContext

var startIndex: Int { return 0 }
var endIndex: Int { return bridged.count }
var endIndex: Int { return bridgedCtxt.getNumVTables() }

subscript(_ index: Int) -> VTable {
assert(index >= startIndex && index < endIndex)
return VTable(bridged: BridgedVTable(vTable: bridged.base![index]))
return VTable(bridged: bridgedCtxt.getVTable(index))
}
}

Expand Down Expand Up @@ -101,9 +101,7 @@ struct ModulePassContext : Context, CustomStringConvertible {
GlobalVariableList(first: _bridged.getFirstGlobalInModule().globalVar)
}

var vTables: VTableArray {
VTableArray(bridged: _bridged.getVTables())
}
var vTables: VTableArray { VTableArray(bridgedCtxt: _bridged) }

var witnessTables: WitnessTableList {
WitnessTableList(first: _bridged.getFirstWitnessTableInModule().witnessTable)
Expand Down Expand Up @@ -131,6 +129,44 @@ struct ModulePassContext : Context, CustomStringConvertible {
return function.isDefinition
}

func specialize(function: Function, for substitutions: SubstitutionMap) -> Function? {
return _bridged.specializeFunction(function.bridged, substitutions.bridged).function
}

enum DeserializationMode {
case allFunctions
case onlySharedFunctions
}

func deserializeAllCallees(of function: Function, mode: DeserializationMode) {
_bridged.deserializeAllCallees(function.bridged, mode == .allFunctions ? true : false)
}

@discardableResult
func createWitnessTable(entries: [WitnessTable.Entry],
conformance: ProtocolConformance,
linkage: Linkage,
serialized: Bool) -> WitnessTable
{
let bridgedEntries = entries.map { $0.bridged }
let bridgedWitnessTable = bridgedEntries.withBridgedArrayRef {
_bridged.createWitnessTable(linkage.bridged, serialized, conformance.bridged, $0)
}
return WitnessTable(bridged: bridgedWitnessTable)
}

@discardableResult
func createSpecializedVTable(entries: [VTable.Entry],
for classType: Type,
isSerialized: Bool) -> VTable
{
let bridgedEntries = entries.map { $0.bridged }
let bridgedVTable = bridgedEntries.withBridgedArrayRef {
_bridged.createSpecializedVTable(classType.bridged, isSerialized, $0)
}
return VTable(bridged: bridgedVTable)
}

func createEmptyFunction(
name: String,
parameters: [ParameterInfo],
Expand Down Expand Up @@ -176,4 +212,8 @@ extension Function {
func set(linkage: Linkage, _ context: ModulePassContext) {
bridged.setLinkage(linkage.bridged)
}

func set(isSerialized: Bool, _ context: ModulePassContext) {
bridged.setIsSerialized(isSerialized)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ swift_compiler_sources(Optimizer
EscapeUtils.swift
ForwardingUtils.swift
FunctionSignatureTransforms.swift
GenericSpecialization.swift
LifetimeDependenceUtils.swift
LocalVariableUtils.swift
OptUtils.swift
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ASTBridging

import Basic
import SIL

public typealias DiagID = BridgedDiagID

Expand All @@ -29,6 +30,16 @@ extension Int: DiagnosticArgument {
fn(BridgedDiagnosticArgument(self))
}
}
extension Type: DiagnosticArgument {
public func _withBridgedDiagnosticArgument(_ fn: (BridgedDiagnosticArgument) -> Void) {
fn(bridged.asDiagnosticArgument())
}
}
extension DeclRef: DiagnosticArgument {
public func _withBridgedDiagnosticArgument(_ fn: (BridgedDiagnosticArgument) -> Void) {
fn(bridged.asDiagnosticArgument())
}
}

public struct DiagnosticFixIt {
public let start: SourceLoc
Expand Down
Loading