Skip to content

Commit eca2801

Browse files
authored
[SR-13929][AutoDiff]: Enable [ossa] for all differentiation-generated thunks (swiftlang#37054)
* [SR-13929][AutoDiff]: Enable [ossa] for Differentiation/Thunk.cpp:getOrCreateSubsetParametersThunkForLinearMap and promoteCurryThunkApplicationToDifferentiableFunction
1 parent 9df8944 commit eca2801

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

include/swift/SILOptimizer/Differentiation/Thunk.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
121121
SILOptFunctionBuilder &fb, SILFunction *assocFn,
122122
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
123123
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
124-
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
124+
const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig,
125125
ADContext &adContext);
126126

127127
} // end namespace autodiff

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
560560
SILOptFunctionBuilder &fb, SILFunction *parentThunk,
561561
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
562562
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
563-
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
563+
const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig,
564564
ADContext &adContext) {
565565
LLVM_DEBUG(getADDebugStream()
566566
<< "Getting a subset parameters thunk for " << linearMapType
@@ -592,8 +592,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
592592
if (!thunk->empty())
593593
return {thunk, interfaceSubs};
594594

595-
// TODO(TF-1206): Enable ownership in all differentiation thunks.
596-
thunk->setOwnershipEliminated();
597595
thunk->setGenericEnvironment(genericEnv);
598596
auto *entry = thunk->createBasicBlock();
599597
TangentBuilder builder(entry, adContext);
@@ -602,6 +600,14 @@ getOrCreateSubsetParametersThunkForLinearMap(
602600
// Get arguments.
603601
SmallVector<SILValue, 4> arguments;
604602
SmallVector<AllocStackInst *, 4> localAllocations;
603+
SmallVector<SILValue, 4> valuesToCleanup;
604+
auto cleanupValues = [&]() {
605+
for (auto value : llvm::reverse(valuesToCleanup))
606+
builder.emitDestroyOperation(loc, value);
607+
608+
for (auto *alloc : llvm::reverse(localAllocations))
609+
builder.createDeallocStack(loc, alloc);
610+
};
605611

606612
// Build a `.zero` argument for the given `Differentiable`-conforming type.
607613
auto buildZeroArgument = [&](SILType zeroSILType) {
@@ -617,10 +623,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
617623
localAllocations.push_back(buf);
618624
builder.emitZeroIntoBuffer(loc, buf, IsInitialization);
619625
if (zeroSILType.isAddress()) {
626+
valuesToCleanup.push_back(buf);
620627
arguments.push_back(buf);
621628
} else {
622629
auto arg = builder.emitLoadValueOperation(loc, buf,
623630
LoadOwnershipQualifier::Take);
631+
valuesToCleanup.push_back(arg);
624632
arguments.push_back(arg);
625633
}
626634
break;
@@ -739,8 +747,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
739747
// If differential thunk, deallocate local allocations and directly return
740748
// `apply` result.
741749
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
742-
for (auto *alloc : llvm::reverse(localAllocations))
743-
builder.createDeallocStack(loc, alloc);
750+
cleanupValues();
744751
builder.createReturn(loc, ai);
745752
return {thunk, interfaceSubs};
746753
}
@@ -787,8 +794,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
787794
}
788795
}
789796
// Deallocate local allocations and return final direct result.
790-
for (auto *alloc : llvm::reverse(localAllocations))
791-
builder.createDeallocStack(loc, alloc);
797+
cleanupValues();
792798
auto result = joinElements(results, builder, loc);
793799
builder.createReturn(loc, result);
794800

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,7 @@ static SILValue promoteCurryThunkApplicationToDifferentiableFunction(
10431043
if (newThunk->empty()) {
10441044
if (auto newThunkGenSig = thunkType->getSubstGenericSignature())
10451045
newThunk->setGenericEnvironment(newThunkGenSig->getGenericEnvironment());
1046-
// TODO(TF-1206): Enable ownership in all differentiation thunks.
1047-
newThunk->setOwnershipEliminated();
1046+
10481047
BasicTypeSubstCloner cloner(thunk, newThunk);
10491048
cloner.cloneFunction();
10501049
auto *retInst = cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());

0 commit comments

Comments
 (0)