Skip to content

Commit 0c73b3e

Browse files
author
Marc Rasi
committed
address comments
1 parent e587c13 commit 0c73b3e

File tree

1 file changed

+5
-17
lines changed

1 file changed

+5
-17
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,7 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22762276
loc, getterVJPRef, /*substitutionMap*/ {},
22772277
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false);
22782278

2279-
// Get the VJP results (original results and pullback)
2279+
// Get the VJP results (original results and pullback).
22802280
SmallVector<SILValue, 8> vjpDirectResults;
22812281
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
22822282
ArrayRef<SILValue> originalDirectResults =
@@ -2291,6 +2291,8 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22912291

22922292
// Checkpoint the original results.
22932293
getPrimalInfo().addStaticPrimalValueDecl(sei);
2294+
getBuilder().createRetainValue(loc, originalDirectResult,
2295+
getBuilder().getDefaultAtomicity());
22942296
staticPrimalValues.push_back(originalDirectResult);
22952297

22962298
// Checkpoint the pullback.
@@ -3612,28 +3614,14 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
36123614
// Construct the pullback arguments.
36133615
SmallVector<SILValue, 8> args;
36143616
auto seed = getAdjointValue(sei);
3615-
auto *seedBuf = builder.createAllocStack(loc, seed.getType());
3616-
materializeAdjointIndirectHelper(seed, seedBuf);
3617-
if (seed.getType().isAddressOnly(getModule()))
3618-
args.push_back(seedBuf);
3619-
else {
3620-
auto access = builder.createBeginAccess(
3621-
loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3622-
/*noNestedConflict*/ true,
3623-
/*fromBuiltin*/ false);
3624-
args.push_back(builder.createLoad(
3625-
loc, access, getBufferLOQ(seed.getSwiftType(), getAdjoint())));
3626-
builder.createEndAccess(loc, access, /*aborted*/ false);
3627-
}
3617+
assert(seed.getType().isObject());
3618+
args.push_back(materializeAdjointDirect(seed, loc));
36283619

36293620
// Call the pullback.
36303621
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
36313622
args, /*isNonThrowing*/ false);
36323623
assert(!pullbackCall->hasIndirectResults());
36333624

3634-
// Clean up seed allocation.
3635-
builder.createDeallocStack(loc, seedBuf);
3636-
36373625
// Set adjoint for the `struct_extract` operand.
36383626
addAdjointValue(sei->getOperand(),
36393627
AdjointValue::getMaterialized(pullbackCall));

0 commit comments

Comments
 (0)