Skip to content

Commit 9f08f5c

Browse files
committed
NFC: minor cleanup.
Stylistic edits. Fix test by removing invalid `@_silgen_name` usage.
1 parent 53b7c63 commit 9f08f5c

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ class JVPCloner::Implementation final
309309
// Tangent buffer mapping
310310
//--------------------------------------------------------------------------//
311311

312+
/// Sets the tangent buffer for the original buffer. Asserts that the
313+
/// original buffer does not already have a tangent buffer.
312314
void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
313315
SILValue tangentBuffer) {
314316
assert(originalBuffer->getType().isAddress());
@@ -318,15 +320,14 @@ class JVPCloner::Implementation final
318320
(void)insertion;
319321
}
320322

321-
/// Returns a tangent buffer for a provided original buffer.
323+
/// Returns the tangent buffer for the original buffer. Asserts that the
324+
/// original buffer has a tangent buffer.
322325
SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
323326
assert(originalBuffer->getType().isAddress());
324327
assert(originalBuffer->getFunction() == original);
325-
auto insertion =
326-
bufferMap.try_emplace({origBB, originalBuffer}, SILValue());
327-
assert(!insertion.second && "Tangent buffer should already exist");
328-
auto &tanBuf = insertion.first->getSecond();
329-
return tanBuf;
328+
auto it = bufferMap.find({origBB, originalBuffer});
329+
assert(it != bufferMap.end() && "Tangent buffer should already exist");
330+
return it->getSecond();
330331
}
331332

332333
//--------------------------------------------------------------------------//
@@ -450,8 +451,8 @@ class JVPCloner::Implementation final
450451
void visitApplyInst(ApplyInst *ai) {
451452
bool shouldDifferentiate =
452453
differentialInfo.shouldDifferentiateApplySite(ai);
453-
// If the function has no active arguments and results, zero-initialize
454-
// tangent buffers for the indirect results.
454+
// If the function has no active arguments or results, zero-initialize the
455+
// tangent buffers of the active indirect results.
455456
if (!shouldDifferentiate) {
456457
for (auto indResult : ai->getIndirectSILResults())
457458
if (activityInfo.isActive(indResult, getIndices())) {
@@ -733,6 +734,7 @@ class JVPCloner::Implementation final
733734
differentialBuilder.emitDestroyAddrAndFold(loc, alloc);
734735
differentialBuilder.createDeallocStack(loc, alloc);
735736
}
737+
736738
// Return a tuple of the original result and differential.
737739
SmallVector<SILValue, 8> directResults;
738740
directResults.append(origResults.begin(), origResults.end());
@@ -818,18 +820,18 @@ class JVPCloner::Implementation final
818820
/// Tangent: tan[y] = load tan[x]
819821
void visitLoadInst(LoadInst *li) {
820822
TypeSubstCloner::visitLoadInst(li);
821-
// If an active buffer is loaded (take) to a non-active value
822-
// we have to uninitialized the buffer.
823+
// If an active buffer is loaded with take to a non-active value, destroy
824+
// the active buffer's tangent buffer.
823825
if (!differentialInfo.shouldDifferentiateInstruction(li)) {
824826
auto isTake =
825827
(li->getOwnershipQualifier() == LoadOwnershipQualifier::Take);
826-
// Destroy `tanBuf`.
827828
if (isTake && activityInfo.isActive(li->getOperand(), getIndices())) {
828829
auto &tanBuf = getTangentBuffer(li->getParent(), li->getOperand());
829830
getDifferentialBuilder().emitDestroyOperation(tanBuf.getLoc(), tanBuf);
830831
}
831832
return;
832833
}
834+
// Otherwise, do standard differential cloning.
833835
auto &diffBuilder = getDifferentialBuilder();
834836
auto *bb = li->getParent();
835837
auto loc = li->getLoc();
@@ -856,17 +858,17 @@ class JVPCloner::Implementation final
856858
/// Tangent: store tan[x] to tan[y]
857859
void visitStoreInst(StoreInst *si) {
858860
TypeSubstCloner::visitStoreInst(si);
859-
// If a non-active value is stored into an active buffer,
860-
// we have to zero-initialized the buffer.
861+
// If a non-active value is stored into an active buffer, zero-initialize
862+
// the active buffer's tangent buffer.
861863
if (!differentialInfo.shouldDifferentiateInstruction(si)) {
862-
// Zero-initialize `tanBufDest`.
863864
if (activityInfo.isActive(si->getDest(), getIndices())) {
864865
auto &tanBufDest = getTangentBuffer(si->getParent(), si->getDest());
865866
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
866867
tanBufDest.getLoc());
867868
}
868869
return;
869870
}
871+
// Otherwise, do standard differential cloning.
870872
auto &diffBuilder = getDifferentialBuilder();
871873
auto loc = si->getLoc();
872874
auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc);
@@ -880,17 +882,17 @@ class JVPCloner::Implementation final
880882
/// Tangent: store_borrow tan[x] to tan[y]
881883
void visitStoreBorrowInst(StoreBorrowInst *sbi) {
882884
TypeSubstCloner::visitStoreBorrowInst(sbi);
883-
// If a non-active value is stored into an active buffer,
884-
// we have to zero-initialized the buffer.
885+
// If a non-active value is stored into an active buffer, zero-initialize
886+
// the active buffer's tangent buffer.
885887
if (!differentialInfo.shouldDifferentiateInstruction(sbi)) {
886-
// Zero-initialize `tanBufDest`.
887888
if (activityInfo.isActive(sbi->getDest(), getIndices())) {
888889
auto &tanBufDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
889890
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
890891
tanBufDest.getLoc());
891892
}
892893
return;
893894
}
895+
// Otherwise, do standard differential cloning.
894896
auto &diffBuilder = getDifferentialBuilder();
895897
auto loc = sbi->getLoc();
896898
auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc);
@@ -903,18 +905,16 @@ class JVPCloner::Implementation final
903905
/// Tangent: copy_addr tan[x] to tan[y]
904906
void visitCopyAddrInst(CopyAddrInst *cai) {
905907
TypeSubstCloner::visitCopyAddrInst(cai);
906-
// If a non-active buffer is copied into an active buffer,
907-
// we have to zero-initialized the buffer.
908-
// If an active buffer is copied (take) into a non-active buffer,
909-
// we have to uninitialize the buffer.
908+
// If a non-active buffer is copied into an active buffer, zero-initialize
909+
// the destination buffer's tangent buffer.
910+
// If an active buffer is copied with take into a non-active buffer, destroy
911+
// the source buffer's tangent buffer.
910912
if (!differentialInfo.shouldDifferentiateInstruction(cai)) {
911-
// Zero-initialize `tanBufDest`.
912913
if (activityInfo.isActive(cai->getDest(), getIndices())) {
913914
auto &tanBufDest = getTangentBuffer(cai->getParent(), cai->getDest());
914915
emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
915916
tanBufDest.getLoc());
916917
}
917-
// Destroy `tanBufSrc`.
918918
if (cai->isTakeOfSrc() &&
919919
activityInfo.isActive(cai->getSrc(), getIndices())) {
920920
auto &tanBufSrc = getTangentBuffer(cai->getParent(), cai->getSrc());
@@ -923,6 +923,7 @@ class JVPCloner::Implementation final
923923
}
924924
return;
925925
}
926+
// Otherwise, do standard differential cloning.
926927
auto diffBuilder = getDifferentialBuilder();
927928
auto loc = cai->getLoc();
928929
auto *bb = cai->getParent();
@@ -1630,7 +1631,7 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
16301631
auto paramIndex =
16311632
std::distance(origTy->getParameters().begin(), &*inoutParamIt);
16321633
// If the original `inout` parameter is a differentiability parameter,
1633-
// then it already has a corresponding differential parameter. Skip adding
1634+
// then it already has a corresponding differential parameter. Do not add
16341635
// a corresponding differential result.
16351636
if (indices.parameters->contains(paramIndex))
16361637
continue;

test/AutoDiff/validation-test/forward_mode.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,10 +1319,9 @@ ForwardModeTests.test("ForceUnwrapping") {
13191319
expectEqual(5, forceUnwrap(Float(2)))
13201320
}
13211321

1322-
ForwardModeTests.test("NonActiveIndirectResult") {
1322+
ForwardModeTests.test("ApplyNonActiveIndirectResult") {
13231323
func identity<T: Differentiable>(_ x: T) -> T { x }
13241324

1325-
@_silgen_name("foo")
13261325
@differentiable
13271326
func applyNonactiveArgumentActiveIndirectResult(_ x: Tracked<Float>) -> Tracked<Float> {
13281327
var y = identity(0 as Tracked<Float>)

0 commit comments

Comments
 (0)