Skip to content

[6.2 🍒] Fix ownership issues with sequences of partial_apply's in AutoDiff closure specialization pass #81240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ import SILBridging

private let verbose = false

private func log(_ message: @autoclosure () -> String) {
private func log(prefix: Bool = true, _ message: @autoclosure () -> String) {
if verbose {
print("### \(message())")
debugLog(prefix: prefix, message())
}
}

Expand All @@ -128,47 +128,48 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
}

var remainingSpecializationRounds = 5
var callerModified = false

repeat {
// TODO: Names here are pretty misleading. We are looking for a place where
// the pullback closure is created (so for `partial_apply` instruction).
var callSites = gatherCallSites(in: function, context)
guard !callSites.isEmpty else {
return
}

if !callSites.isEmpty {
for callSite in callSites {
var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context)

if !alreadyExists {
context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee)
}
for callSite in callSites {
var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context)

rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)
if !alreadyExists {
context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee)
}

var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in
callSite.closureArgDescriptors
.map { $0.closure }
.forEach { deadClosures.pushIfNotVisited($0) }
}
rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context)
}

defer {
deadClosures.deinitialize()
}
var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in
callSite.closureArgDescriptors
.map { $0.closure }
.forEach { deadClosures.pushIfNotVisited($0) }
}

while let deadClosure = deadClosures.pop() {
let isDeleted = context.tryDeleteDeadClosure(closure: deadClosure as! SingleValueInstruction)
if isDeleted {
context.notifyInvalidatedStackNesting()
}
}
defer {
deadClosures.deinitialize()
}

if context.needFixStackNesting {
function.fixStackNesting(context)
while let deadClosure = deadClosures.pop() {
let isDeleted = context.tryDeleteDeadClosure(closure: deadClosure as! SingleValueInstruction)
if isDeleted {
context.notifyInvalidatedStackNesting()
}
}

callerModified = callSites.count > 0
if context.needFixStackNesting {
function.fixStackNesting(context)
}

remainingSpecializationRounds -= 1
} while callerModified && remainingSpecializationRounds > 0
} while remainingSpecializationRounds > 0
}

// =========== Top-level functions ========== //
Expand Down Expand Up @@ -503,12 +504,6 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
continue
}

// Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78847
// TODO: remove this if-statement once the underlying problem is fixed.
if callee.hasOwnership {
continue
}

if callee.isDefinedExternally {
continue
}
Expand Down Expand Up @@ -779,13 +774,13 @@ private extension SpecializationCloner {

let clonedRootClosure = builder.cloneRootClosure(representedBy: closureArgDesc, capturedArguments: clonedClosureArgs)

let (finalClonedReabstractedClosure, releasableClonedReabstractedClosures) =
let finalClonedReabstractedClosure =
builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure,
reabstractedClosure: callSite.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!,
origToClonedValueMap: origToClonedValueMap,
self.context)

let allClonedReleasableClosures = [clonedRootClosure] + releasableClonedReabstractedClosures
let allClonedReleasableClosures = [ finalClonedReabstractedClosure ];
return (finalClonedReabstractedClosure, allClonedReleasableClosures)
}

Expand Down Expand Up @@ -935,10 +930,9 @@ private extension Builder {

func cloneRootClosureReabstractions(rootClosure: Value, clonedRootClosure: Value, reabstractedClosure: Value,
origToClonedValueMap: [HashableValue: Value], _ context: FunctionPassContext)
-> (finalClonedReabstractedClosure: SingleValueInstruction, releasableClonedReabstractedClosures: [PartialApplyInst])
-> SingleValueInstruction
{
func inner(_ rootClosure: Value, _ clonedRootClosure: Value, _ reabstractedClosure: Value,
_ releasableClonedReabstractedClosures: inout [PartialApplyInst],
_ origToClonedValueMap: inout [HashableValue: Value]) -> Value {
switch reabstractedClosure {
case let reabstractedClosure where reabstractedClosure == rootClosure:
Expand All @@ -947,23 +941,23 @@ private extension Builder {

case let cvt as ConvertFunctionInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction,
&releasableClonedReabstractedClosures, &origToClonedValueMap)
&origToClonedValueMap)
let reabstracted = self.createConvertFunction(originalFunction: toBeReabstracted, resultType: cvt.type,
withoutActuallyEscaping: cvt.withoutActuallyEscaping)
origToClonedValueMap[cvt] = reabstracted
return reabstracted

case let cvt as ConvertEscapeToNoEscapeInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction,
&releasableClonedReabstractedClosures, &origToClonedValueMap)
&origToClonedValueMap)
let reabstracted = self.createConvertEscapeToNoEscape(originalFunction: toBeReabstracted, resultType: cvt.type,
isLifetimeGuaranteed: true)
origToClonedValueMap[cvt] = reabstracted
return reabstracted

case let pai as PartialApplyInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, pai.arguments[0],
&releasableClonedReabstractedClosures, &origToClonedValueMap)
&origToClonedValueMap)

guard let function = pai.referencedFunction else {
log("Parent function of callSite: \(rootClosure.parentFunction)")
Expand All @@ -978,13 +972,11 @@ private extension Builder {
calleeConvention: pai.calleeConvention,
hasUnknownResultIsolation: pai.hasUnknownResultIsolation,
isOnStack: pai.isOnStack)
releasableClonedReabstractedClosures.append(reabstracted)
origToClonedValueMap[pai] = reabstracted
return reabstracted

case let mdi as MarkDependenceInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &releasableClonedReabstractedClosures,
&origToClonedValueMap)
let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &origToClonedValueMap)
let base = origToClonedValueMap[mdi.base]!
let reabstracted = self.createMarkDependence(value: toBeReabstracted, base: base, kind: .Escaping)
origToClonedValueMap[mdi] = reabstracted
Expand All @@ -998,11 +990,10 @@ private extension Builder {
}
}

var releasableClonedReabstractedClosures: [PartialApplyInst] = []
var origToClonedValueMap = origToClonedValueMap
let finalClonedReabstractedClosure = inner(rootClosure, clonedRootClosure, reabstractedClosure,
&releasableClonedReabstractedClosures, &origToClonedValueMap)
return (finalClonedReabstractedClosure as! SingleValueInstruction, releasableClonedReabstractedClosures)
&origToClonedValueMap)
return (finalClonedReabstractedClosure as! SingleValueInstruction)
}

func destroyPartialApply(pai: PartialApplyInst, _ context: FunctionPassContext){
Expand Down
Loading