@@ -1392,12 +1392,14 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1392
1392
1393
1393
auto thunkTy = thunk->getLoweredFunctionType ();
1394
1394
auto thunkResult = thunkTy->getSingleResult ();
1395
- if (auto resultFnTy = thunkResult.getInterfaceType ()->getAs <SILFunctionType>()) {
1396
- // Construct new curry thunk type with `@differentiable` result.
1397
- auto diffableResultFnTy = resultFnTy->getWithExtInfo (
1398
- resultFnTy->getExtInfo ()
1399
- .withDifferentiabilityKind (DifferentiabilityKind::Normal));
1400
- auto newThunkResult = thunkResult.getWithInterfaceType (diffableResultFnTy);
1395
+ if (auto resultFnTy =
1396
+ thunkResult.getInterfaceType ()->getAs <SILFunctionType>()) {
1397
+ // Construct new curry thunk type with `@differentiable` function
1398
+ // result.
1399
+ auto diffResultFnTy = resultFnTy->getWithExtInfo (
1400
+ resultFnTy->getExtInfo ().withDifferentiabilityKind (
1401
+ DifferentiabilityKind::Normal));
1402
+ auto newThunkResult = thunkResult.getWithInterfaceType (diffResultFnTy);
1401
1403
auto thunkType = SILFunctionType::get (
1402
1404
thunkTy->getSubstGenericSignature (), thunkTy->getExtInfo (),
1403
1405
thunkTy->getCoroutineKind (), thunkTy->getCalleeConvention (),
@@ -1425,12 +1427,18 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1425
1427
cloner.run ();
1426
1428
auto *retInst =
1427
1429
cast<ReturnInst>(newThunk->findReturnBB ()->getTerminator ());
1428
- SILBuilder thunkBuilder (retInst);
1429
- auto *dfi = context.createDifferentiableFunction (thunkBuilder, loc,
1430
- parameterIndices,
1431
- retInst->getOperand ());
1430
+ auto returnValue = retInst->getOperand ();
1431
+ // Create `differentiable_function` instruction directly after the
1432
+ // defining instruction (e.g. `partial_apply`) of the returned value.
1433
+ // Note: `differentiable_function` is not created at the end of the
1434
+ // new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues.
1435
+ SILBuilder dfiBuilder (
1436
+ std::next (returnValue->getDefiningInstruction ()->getIterator ()));
1437
+ auto *dfi = context.createDifferentiableFunction (
1438
+ dfiBuilder, loc, parameterIndices, returnValue);
1432
1439
context.setResultIndex (dfi, resultIndex);
1433
- thunkBuilder.createReturn (loc, dfi);
1440
+ dfiBuilder.setInsertionPoint (newThunk->findReturnBB ());
1441
+ dfiBuilder.createReturn (loc, dfi);
1434
1442
retInst->eraseFromParent ();
1435
1443
1436
1444
context.recordGeneratedFunction (newThunk);
@@ -1450,12 +1458,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1450
1458
auto *newApply = builder.createApply (
1451
1459
ai->getLoc (), newThunkRef, ai->getSubstitutionMap (), newArgs,
1452
1460
ai->isNonThrowing ());
1453
- for (auto arg : newArgsToDestroy) {
1454
- if (arg->getType ().isObject ())
1455
- builder.emitDestroyValueOperation (loc, arg);
1456
- else
1457
- builder.emitDestroyAddr (loc, arg);
1458
- }
1461
+ for (auto arg : newArgsToDestroy)
1462
+ builder.emitDestroyOperation (loc, arg);
1459
1463
for (auto *alloc : newBuffersToDealloc)
1460
1464
builder.createDeallocStack (loc, alloc);
1461
1465
return newApply;
0 commit comments