@@ -1837,18 +1837,15 @@ static SILValue
1837
1837
reapplyFunctionConversion (SILValue newFunc, SILValue oldFunc,
1838
1838
SILValue oldConvertedFunc, SILBuilder &builder,
1839
1839
SILLocation loc,
1840
- GenericSignature* newFuncGenSig = nullptr ,
1841
- std::function<SILValue(SILValue)> substituteOperand =
1842
- [](SILValue v) { return v; }) {
1840
+ GenericSignature *newFuncGenSig = nullptr ) {
1843
1841
// If the old func is the new func, then there's no conversion.
1844
1842
if (oldFunc == oldConvertedFunc)
1845
1843
return newFunc;
1846
1844
// Handle a few instruction cases.
1847
1845
// thin_to_thick_function
1848
1846
if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) {
1849
1847
auto innerNewFunc = reapplyFunctionConversion (
1850
- newFunc, oldFunc, tttfi->getOperand (), builder, loc, newFuncGenSig,
1851
- substituteOperand);
1848
+ newFunc, oldFunc, tttfi->getOperand (), builder, loc, newFuncGenSig);
1852
1849
auto operandFnTy = innerNewFunc->getType ().castTo <SILFunctionType>();
1853
1850
auto thickTy = operandFnTy->getWithRepresentation (
1854
1851
SILFunctionTypeRepresentation::Thick);
@@ -1860,11 +1857,17 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1860
1857
if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
1861
1858
SmallVector<SILValue, 8 > newArgs;
1862
1859
newArgs.reserve (pai->getNumArguments ());
1863
- for (auto arg : pai->getArguments ())
1864
- newArgs.push_back (substituteOperand (arg));
1860
+ for (auto arg : pai->getArguments ()) {
1861
+ // Retain the argument since it's to be owned by the newly created
1862
+ // closure.
1863
+ if (arg->getType ().isObject ())
1864
+ builder.createRetainValue (loc, arg, builder.getDefaultAtomicity ());
1865
+ else if (arg->getType ().isLoadable (builder.getFunction ()))
1866
+ builder.createRetainValueAddr (loc, arg, builder.getDefaultAtomicity ());
1867
+ newArgs.push_back (arg);
1868
+ }
1865
1869
auto innerNewFunc = reapplyFunctionConversion (
1866
- newFunc, oldFunc, pai->getCallee (), builder, loc, newFuncGenSig,
1867
- substituteOperand);
1870
+ newFunc, oldFunc, pai->getCallee (), builder, loc, newFuncGenSig);
1868
1871
// If new function's generic signature is specified, use it to create
1869
1872
// substitution map for reapplied `partial_apply` instruction.
1870
1873
auto substMap = !newFuncGenSig
@@ -1879,8 +1882,7 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1879
1882
if (auto *cetn = dyn_cast<ConvertEscapeToNoEscapeInst>(oldConvertedFunc)) {
1880
1883
auto innerNewFunc = reapplyFunctionConversion (newFunc, oldFunc,
1881
1884
cetn->getOperand (), builder,
1882
- loc, newFuncGenSig,
1883
- substituteOperand);
1885
+ loc, newFuncGenSig);
1884
1886
auto operandFnTy = innerNewFunc->getType ().castTo <SILFunctionType>();
1885
1887
auto noEscapeType = operandFnTy->getWithExtInfo (
1886
1888
operandFnTy->getExtInfo ().withNoEscape ());
@@ -1899,8 +1901,7 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1899
1901
cfi->getOperand ()->getType ().castTo <SILFunctionType>();
1900
1902
auto innerNewFunc = reapplyFunctionConversion (newFunc, oldFunc,
1901
1903
cfi->getOperand (), builder,
1902
- loc, newFuncGenSig,
1903
- substituteOperand);
1904
+ loc, newFuncGenSig);
1904
1905
// Match a conversion from escaping to `@noescape`
1905
1906
CanSILFunctionType targetType;
1906
1907
if (!origSourceFnTy->isNoEscape () && origTargetFnTy->isNoEscape () &&
@@ -3205,7 +3206,7 @@ class VJPEmitter final
3205
3206
}
3206
3207
}
3207
3208
vjpValue = builder.createAutoDiffFunctionExtract (
3208
- original. getLoc () , AutoDiffFunctionExtractInst::Extractee::VJP,
3209
+ loc , AutoDiffFunctionExtractInst::Extractee::VJP,
3209
3210
/* differentiationOrder*/ 1 , functionSource);
3210
3211
}
3211
3212
@@ -3234,6 +3235,7 @@ class VJPEmitter final
3234
3235
// on the remapped original function operand and `autodiff_function_extract`
3235
3236
// the VJP. The actual JVP/VJP functions will be populated in the
3236
3237
// `autodiff_function` during the transform main loop.
3238
+ SILValue differentiableFunc;
3237
3239
if (!vjpValue) {
3238
3240
// FIXME: Handle indirect differentiation invokers. This may require some
3239
3241
// redesign: currently, each original function + attribute pair is mapped
@@ -3251,7 +3253,9 @@ class VJPEmitter final
3251
3253
// In the VJP, specialization is also necessary for parity. The original
3252
3254
// function operand is specialized with a remapped version of same
3253
3255
// substitution map using an argument-less `partial_apply`.
3254
- if (!ai->getSubstitutionMap ().empty ()) {
3256
+ if (ai->getSubstitutionMap ().empty ()) {
3257
+ builder.createRetainValue (loc, original, builder.getDefaultAtomicity ());
3258
+ } else {
3255
3259
auto substMap = getOpSubstitutionMap (ai->getSubstitutionMap ());
3256
3260
auto vjpPartialApply = getBuilder ().createPartialApply (
3257
3261
ai->getLoc (), original, substMap, {},
@@ -3262,6 +3266,7 @@ class VJPEmitter final
3262
3266
auto *autoDiffFuncInst = context.createAutoDiffFunction (
3263
3267
getBuilder (), loc, indices.parameters , /* differentiationOrder*/ 1 ,
3264
3268
original);
3269
+ differentiableFunc = autoDiffFuncInst;
3265
3270
3266
3271
// Record the `autodiff_function` instruction.
3267
3272
context.getAutoDiffFunctionInsts ().push_back (autoDiffFuncInst);
@@ -3296,6 +3301,11 @@ class VJPEmitter final
3296
3301
vjpArgs, ai->isNonThrowing ());
3297
3302
LLVM_DEBUG (getADDebugStream () << " Applied vjp function\n " << *vjpCall);
3298
3303
3304
+ // Release the differentiable function.
3305
+ if (differentiableFunc)
3306
+ builder.createReleaseValue (loc, differentiableFunc,
3307
+ builder.getDefaultAtomicity ());
3308
+
3299
3309
// Get the VJP results (original results and pullback).
3300
3310
SmallVector<SILValue, 8 > vjpDirectResults;
3301
3311
extractAllElements (vjpCall, getBuilder (), vjpDirectResults);
@@ -6365,7 +6375,6 @@ SILValue ADContext::promoteToDifferentiableFunction(
6365
6375
loc, assocFn, SILType::getPrimitiveObjectType (expectedAssocFnTy));
6366
6376
}
6367
6377
6368
- builder.createRetainValue (loc, assocFn, builder.getDefaultAtomicity ());
6369
6378
assocFns.push_back (assocFn);
6370
6379
}
6371
6380
@@ -6384,6 +6393,8 @@ SILValue ADContext::promoteToDifferentiableFunction(
6384
6393
// /
6385
6394
// / Folding can be disabled by the `SkipFoldingAutoDiffFunctionExtraction` flag
6386
6395
// / for SIL testing purposes.
6396
+ // FIXME: This function is not correctly detecting the foldable pattern and
6397
+ // needs to be rewritten.
6387
6398
void ADContext::foldAutoDiffFunctionExtraction (AutoDiffFunctionInst *source) {
6388
6399
// Iterate through all `autodiff_function` instruction uses.
6389
6400
for (auto use : source->getUses ()) {
0 commit comments