@@ -560,7 +560,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
560
560
SILOptFunctionBuilder &fb, SILFunction *parentThunk,
561
561
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
562
562
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
563
- AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
563
+ const AutoDiffConfig & desiredConfig, const AutoDiffConfig & actualConfig,
564
564
ADContext &adContext) {
565
565
LLVM_DEBUG (getADDebugStream ()
566
566
<< " Getting a subset parameters thunk for " << linearMapType
@@ -592,8 +592,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
592
592
if (!thunk->empty ())
593
593
return {thunk, interfaceSubs};
594
594
595
- // TODO(TF-1206): Enable ownership in all differentiation thunks.
596
- thunk->setOwnershipEliminated ();
597
595
thunk->setGenericEnvironment (genericEnv);
598
596
auto *entry = thunk->createBasicBlock ();
599
597
TangentBuilder builder (entry, adContext);
@@ -602,6 +600,14 @@ getOrCreateSubsetParametersThunkForLinearMap(
602
600
// Get arguments.
603
601
SmallVector<SILValue, 4 > arguments;
604
602
SmallVector<AllocStackInst *, 4 > localAllocations;
603
+ SmallVector<SILValue, 4 > valuesToCleanup;
604
+ auto cleanupValues = [&]() {
605
+ for (auto value : llvm::reverse (valuesToCleanup))
606
+ builder.emitDestroyOperation (loc, value);
607
+
608
+ for (auto *alloc : llvm::reverse (localAllocations))
609
+ builder.createDeallocStack (loc, alloc);
610
+ };
605
611
606
612
// Build a `.zero` argument for the given `Differentiable`-conforming type.
607
613
auto buildZeroArgument = [&](SILType zeroSILType) {
@@ -617,10 +623,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
617
623
localAllocations.push_back (buf);
618
624
builder.emitZeroIntoBuffer (loc, buf, IsInitialization);
619
625
if (zeroSILType.isAddress ()) {
626
+ valuesToCleanup.push_back (buf);
620
627
arguments.push_back (buf);
621
628
} else {
622
629
auto arg = builder.emitLoadValueOperation (loc, buf,
623
630
LoadOwnershipQualifier::Take);
631
+ valuesToCleanup.push_back (arg);
624
632
arguments.push_back (arg);
625
633
}
626
634
break ;
@@ -739,8 +747,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
739
747
// If differential thunk, deallocate local allocations and directly return
740
748
// `apply` result.
741
749
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
742
- for (auto *alloc : llvm::reverse (localAllocations))
743
- builder.createDeallocStack (loc, alloc);
750
+ cleanupValues ();
744
751
builder.createReturn (loc, ai);
745
752
return {thunk, interfaceSubs};
746
753
}
@@ -787,8 +794,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
787
794
}
788
795
}
789
796
// Deallocate local allocations and return final direct result.
790
- for (auto *alloc : llvm::reverse (localAllocations))
791
- builder.createDeallocStack (loc, alloc);
797
+ cleanupValues ();
792
798
auto result = joinElements (results, builder, loc);
793
799
builder.createReturn (loc, result);
794
800
0 commit comments