@@ -1892,18 +1892,15 @@ static SILValue
1892
1892
reapplyFunctionConversion (SILValue newFunc, SILValue oldFunc,
1893
1893
SILValue oldConvertedFunc, SILBuilder &builder,
1894
1894
SILLocation loc,
1895
- GenericSignature* newFuncGenSig = nullptr ,
1896
- std::function<SILValue(SILValue)> substituteOperand =
1897
- [](SILValue v) { return v; }) {
1895
+ GenericSignature *newFuncGenSig = nullptr ) {
1898
1896
// If the old func is the new func, then there's no conversion.
1899
1897
if (oldFunc == oldConvertedFunc)
1900
1898
return newFunc;
1901
1899
// Handle a few instruction cases.
1902
1900
// thin_to_thick_function
1903
1901
if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) {
1904
1902
auto innerNewFunc = reapplyFunctionConversion (
1905
- newFunc, oldFunc, tttfi->getOperand (), builder, loc, newFuncGenSig,
1906
- substituteOperand);
1903
+ newFunc, oldFunc, tttfi->getOperand (), builder, loc, newFuncGenSig);
1907
1904
auto operandFnTy = innerNewFunc->getType ().castTo <SILFunctionType>();
1908
1905
auto thickTy = operandFnTy->getWithRepresentation (
1909
1906
SILFunctionTypeRepresentation::Thick);
@@ -1915,11 +1912,17 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1915
1912
if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
1916
1913
SmallVector<SILValue, 8 > newArgs;
1917
1914
newArgs.reserve (pai->getNumArguments ());
1918
- for (auto arg : pai->getArguments ())
1919
- newArgs.push_back (substituteOperand (arg));
1915
+ for (auto arg : pai->getArguments ()) {
1916
+ // Retain the argument since it's to be owned by the newly created
1917
+ // closure.
1918
+ if (arg->getType ().isObject ())
1919
+ builder.createRetainValue (loc, arg, builder.getDefaultAtomicity ());
1920
+ else if (arg->getType ().isLoadable (builder.getFunction ()))
1921
+ builder.createRetainValueAddr (loc, arg, builder.getDefaultAtomicity ());
1922
+ newArgs.push_back (arg);
1923
+ }
1920
1924
auto innerNewFunc = reapplyFunctionConversion (
1921
- newFunc, oldFunc, pai->getCallee (), builder, loc, newFuncGenSig,
1922
- substituteOperand);
1925
+ newFunc, oldFunc, pai->getCallee (), builder, loc, newFuncGenSig);
1923
1926
// If new function's generic signature is specified, use it to create
1924
1927
// substitution map for reapplied `partial_apply` instruction.
1925
1928
auto substMap = !newFuncGenSig
@@ -1934,8 +1937,7 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1934
1937
if (auto *cetn = dyn_cast<ConvertEscapeToNoEscapeInst>(oldConvertedFunc)) {
1935
1938
auto innerNewFunc = reapplyFunctionConversion (newFunc, oldFunc,
1936
1939
cetn->getOperand (), builder,
1937
- loc, newFuncGenSig,
1938
- substituteOperand);
1940
+ loc, newFuncGenSig);
1939
1941
auto operandFnTy = innerNewFunc->getType ().castTo <SILFunctionType>();
1940
1942
auto noEscapeType = operandFnTy->getWithExtInfo (
1941
1943
operandFnTy->getExtInfo ().withNoEscape ());
@@ -1954,8 +1956,7 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1954
1956
cfi->getOperand ()->getType ().castTo <SILFunctionType>();
1955
1957
auto innerNewFunc = reapplyFunctionConversion (newFunc, oldFunc,
1956
1958
cfi->getOperand (), builder,
1957
- loc, newFuncGenSig,
1958
- substituteOperand);
1959
+ loc, newFuncGenSig);
1959
1960
// Match a conversion from escaping to `@noescape`
1960
1961
CanSILFunctionType targetType;
1961
1962
if (!origSourceFnTy->isNoEscape () && origTargetFnTy->isNoEscape () &&
@@ -3260,7 +3261,7 @@ class VJPEmitter final
3260
3261
}
3261
3262
}
3262
3263
vjpValue = builder.createAutoDiffFunctionExtract (
3263
- original. getLoc () , AutoDiffFunctionExtractInst::Extractee::VJP,
3264
+ loc , AutoDiffFunctionExtractInst::Extractee::VJP,
3264
3265
/* differentiationOrder*/ 1 , functionSource);
3265
3266
}
3266
3267
@@ -3289,6 +3290,7 @@ class VJPEmitter final
3289
3290
// on the remapped original function operand and `autodiff_function_extract`
3290
3291
// the VJP. The actual JVP/VJP functions will be populated in the
3291
3292
// `autodiff_function` during the transform main loop.
3293
+ SILValue differentiableFunc;
3292
3294
if (!vjpValue) {
3293
3295
// FIXME: Handle indirect differentiation invokers. This may require some
3294
3296
// redesign: currently, each original function + attribute pair is mapped
@@ -3306,7 +3308,9 @@ class VJPEmitter final
3306
3308
// In the VJP, specialization is also necessary for parity. The original
3307
3309
// function operand is specialized with a remapped version of same
3308
3310
// substitution map using an argument-less `partial_apply`.
3309
- if (!ai->getSubstitutionMap ().empty ()) {
3311
+ if (ai->getSubstitutionMap ().empty ()) {
3312
+ builder.createRetainValue (loc, original, builder.getDefaultAtomicity ());
3313
+ } else {
3310
3314
auto substMap = getOpSubstitutionMap (ai->getSubstitutionMap ());
3311
3315
auto vjpPartialApply = getBuilder ().createPartialApply (
3312
3316
ai->getLoc (), original, substMap, {},
@@ -3317,6 +3321,7 @@ class VJPEmitter final
3317
3321
auto *autoDiffFuncInst = context.createAutoDiffFunction (
3318
3322
getBuilder (), loc, indices.parameters , /* differentiationOrder*/ 1 ,
3319
3323
original);
3324
+ differentiableFunc = autoDiffFuncInst;
3320
3325
3321
3326
// Record the `autodiff_function` instruction.
3322
3327
context.getAutoDiffFunctionInsts ().push_back (autoDiffFuncInst);
@@ -3351,6 +3356,11 @@ class VJPEmitter final
3351
3356
vjpArgs, ai->isNonThrowing ());
3352
3357
LLVM_DEBUG (getADDebugStream () << " Applied vjp function\n " << *vjpCall);
3353
3358
3359
+ // Release the differentiable function.
3360
+ if (differentiableFunc)
3361
+ builder.createReleaseValue (loc, differentiableFunc,
3362
+ builder.getDefaultAtomicity ());
3363
+
3354
3364
// Get the VJP results (original results and pullback).
3355
3365
SmallVector<SILValue, 8 > vjpDirectResults;
3356
3366
extractAllElements (vjpCall, getBuilder (), vjpDirectResults);
@@ -6566,7 +6576,6 @@ SILValue ADContext::promoteToDifferentiableFunction(
6566
6576
loc, assocFn, SILType::getPrimitiveObjectType (expectedAssocFnTy));
6567
6577
}
6568
6578
6569
- builder.createRetainValue (loc, assocFn, builder.getDefaultAtomicity ());
6570
6579
assocFns.push_back (assocFn);
6571
6580
}
6572
6581
@@ -6585,6 +6594,8 @@ SILValue ADContext::promoteToDifferentiableFunction(
6585
6594
// /
6586
6595
// / Folding can be disabled by the `SkipFoldingAutoDiffFunctionExtraction` flag
6587
6596
// / for SIL testing purposes.
6597
+ // FIXME: This function is not correctly detecting the foldable pattern and
6598
+ // needs to be rewritten.
6588
6599
void ADContext::foldAutoDiffFunctionExtraction (AutoDiffFunctionInst *source) {
6589
6600
// Iterate through all `autodiff_function` instruction uses.
6590
6601
for (auto use : source->getUses ()) {
0 commit comments