@@ -2276,7 +2276,7 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2276
2276
loc, getterVJPRef, /* substitutionMap*/ {},
2277
2277
/* args*/ {getMappedValue (sei->getOperand ())}, /* isNonThrowing*/ false );
2278
2278
2279
- // Get the VJP results (original results and pullback)
2279
+ // Get the VJP results (original results and pullback).
2280
2280
SmallVector<SILValue, 8 > vjpDirectResults;
2281
2281
extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
2282
2282
ArrayRef<SILValue> originalDirectResults =
@@ -2291,6 +2291,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2291
2291
2292
2292
// Checkpoint the original results.
2293
2293
getPrimalInfo ().addStaticPrimalValueDecl (sei);
2294
+ getBuilder ().createRetainValue (loc, originalDirectResult,
2295
+ getBuilder ().getDefaultAtomicity ());
2294
2296
staticPrimalValues.push_back (originalDirectResult);
2295
2297
2296
2298
// Checkpoint the pullback.
@@ -3612,28 +3614,14 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3612
3614
// Construct the pullback arguments.
3613
3615
SmallVector<SILValue, 8 > args;
3614
3616
auto seed = getAdjointValue (sei);
3615
- auto *seedBuf = builder.createAllocStack (loc, seed.getType ());
3616
- materializeAdjointIndirectHelper (seed, seedBuf);
3617
- if (seed.getType ().isAddressOnly (getModule ()))
3618
- args.push_back (seedBuf);
3619
- else {
3620
- auto access = builder.createBeginAccess (
3621
- loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3622
- /* noNestedConflict*/ true ,
3623
- /* fromBuiltin*/ false );
3624
- args.push_back (builder.createLoad (
3625
- loc, access, getBufferLOQ (seed.getSwiftType (), getAdjoint ())));
3626
- builder.createEndAccess (loc, access, /* aborted*/ false );
3627
- }
3617
+ assert (seed.getType ().isObject ());
3618
+ args.push_back (materializeAdjointDirect (seed, loc));
3628
3619
3629
3620
// Call the pullback.
3630
3621
auto *pullbackCall = builder.createApply (loc, pullback, SubstitutionMap (),
3631
3622
args, /* isNonThrowing*/ false );
3632
3623
assert (!pullbackCall->hasIndirectResults ());
3633
3624
3634
- // Clean up seed allocation.
3635
- builder.createDeallocStack (loc, seedBuf);
3636
-
3637
3625
// Set adjoint for the `struct_extract` operand.
3638
3626
addAdjointValue (sei->getOperand (),
3639
3627
AdjointValue::getMaterialized (pullbackCall));
0 commit comments