@@ -309,6 +309,8 @@ class JVPCloner::Implementation final
309
309
// Tangent buffer mapping
310
310
// --------------------------------------------------------------------------//
311
311
312
+ // / Sets the tangent buffer for the original buffer. Asserts that the
313
+ // / original buffer does not already have a tangent buffer.
312
314
void setTangentBuffer (SILBasicBlock *origBB, SILValue originalBuffer,
313
315
SILValue tangentBuffer) {
314
316
assert (originalBuffer->getType ().isAddress ());
@@ -318,15 +320,14 @@ class JVPCloner::Implementation final
318
320
(void )insertion;
319
321
}
320
322
321
- // / Returns a tangent buffer for a provided original buffer.
323
+ // / Returns the tangent buffer for the original buffer. Asserts that the
324
+ // / original buffer has a tangent buffer.
322
325
SILValue &getTangentBuffer (SILBasicBlock *origBB, SILValue originalBuffer) {
323
326
assert (originalBuffer->getType ().isAddress ());
324
327
assert (originalBuffer->getFunction () == original);
325
- auto insertion =
326
- bufferMap.try_emplace ({origBB, originalBuffer}, SILValue ());
327
- assert (!insertion.second && " Tangent buffer should already exist" );
328
- auto &tanBuf = insertion.first ->getSecond ();
329
- return tanBuf;
328
+ auto it = bufferMap.find ({origBB, originalBuffer});
329
+ assert (it != bufferMap.end () && " Tangent buffer should already exist" );
330
+ return it->getSecond ();
330
331
}
331
332
332
333
// --------------------------------------------------------------------------//
@@ -450,8 +451,8 @@ class JVPCloner::Implementation final
450
451
void visitApplyInst (ApplyInst *ai) {
451
452
bool shouldDifferentiate =
452
453
differentialInfo.shouldDifferentiateApplySite (ai);
453
- // If the function has no active arguments and results, zero-initialize
454
- // tangent buffers for the indirect results.
454
+ // If the function has no active arguments or results, zero-initialize the
455
+ // tangent buffers of the active indirect results.
455
456
if (!shouldDifferentiate) {
456
457
for (auto indResult : ai->getIndirectSILResults ())
457
458
if (activityInfo.isActive (indResult, getIndices ())) {
@@ -733,6 +734,7 @@ class JVPCloner::Implementation final
733
734
differentialBuilder.emitDestroyAddrAndFold (loc, alloc);
734
735
differentialBuilder.createDeallocStack (loc, alloc);
735
736
}
737
+
736
738
// Return a tuple of the original result and differential.
737
739
SmallVector<SILValue, 8 > directResults;
738
740
directResults.append (origResults.begin (), origResults.end ());
@@ -818,18 +820,18 @@ class JVPCloner::Implementation final
818
820
// / Tangent: tan[y] = load tan[x]
819
821
void visitLoadInst (LoadInst *li) {
820
822
TypeSubstCloner::visitLoadInst (li);
821
- // If an active buffer is loaded ( take) to a non-active value
822
- // we have to uninitialized the buffer.
823
+ // If an active buffer is loaded with take to a non-active value, destroy
824
+ // the active buffer's tangent buffer.
823
825
if (!differentialInfo.shouldDifferentiateInstruction (li)) {
824
826
auto isTake =
825
827
(li->getOwnershipQualifier () == LoadOwnershipQualifier::Take);
826
- // Destroy `tanBuf`.
827
828
if (isTake && activityInfo.isActive (li->getOperand (), getIndices ())) {
828
829
auto &tanBuf = getTangentBuffer (li->getParent (), li->getOperand ());
829
830
getDifferentialBuilder ().emitDestroyOperation (tanBuf.getLoc (), tanBuf);
830
831
}
831
832
return ;
832
833
}
834
+ // Otherwise, do standard differential cloning.
833
835
auto &diffBuilder = getDifferentialBuilder ();
834
836
auto *bb = li->getParent ();
835
837
auto loc = li->getLoc ();
@@ -856,17 +858,17 @@ class JVPCloner::Implementation final
856
858
// / Tangent: store tan[x] to tan[y]
857
859
void visitStoreInst (StoreInst *si) {
858
860
TypeSubstCloner::visitStoreInst (si);
859
- // If a non-active value is stored into an active buffer,
860
- // we have to zero-initialized the buffer.
861
+ // If a non-active value is stored into an active buffer, zero-initialize
862
+ // the active buffer's tangent buffer.
861
863
if (!differentialInfo.shouldDifferentiateInstruction (si)) {
862
- // Zero-initialize `tanBufDest`.
863
864
if (activityInfo.isActive (si->getDest (), getIndices ())) {
864
865
auto &tanBufDest = getTangentBuffer (si->getParent (), si->getDest ());
865
866
emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
866
867
tanBufDest.getLoc ());
867
868
}
868
869
return ;
869
870
}
871
+ // Otherwise, do standard differential cloning.
870
872
auto &diffBuilder = getDifferentialBuilder ();
871
873
auto loc = si->getLoc ();
872
874
auto tanValSrc = materializeTangent (getTangentValue (si->getSrc ()), loc);
@@ -880,17 +882,17 @@ class JVPCloner::Implementation final
880
882
// / Tangent: store_borrow tan[x] to tan[y]
881
883
void visitStoreBorrowInst (StoreBorrowInst *sbi) {
882
884
TypeSubstCloner::visitStoreBorrowInst (sbi);
883
- // If a non-active value is stored into an active buffer,
884
- // we have to zero-initialized the buffer.
885
+ // If a non-active value is stored into an active buffer, zero-initialize
886
+ // the active buffer's tangent buffer.
885
887
if (!differentialInfo.shouldDifferentiateInstruction (sbi)) {
886
- // Zero-initialize `tanBufDest`.
887
888
if (activityInfo.isActive (sbi->getDest (), getIndices ())) {
888
889
auto &tanBufDest = getTangentBuffer (sbi->getParent (), sbi->getDest ());
889
890
emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
890
891
tanBufDest.getLoc ());
891
892
}
892
893
return ;
893
894
}
895
+ // Otherwise, do standard differential cloning.
894
896
auto &diffBuilder = getDifferentialBuilder ();
895
897
auto loc = sbi->getLoc ();
896
898
auto tanValSrc = materializeTangent (getTangentValue (sbi->getSrc ()), loc);
@@ -903,18 +905,16 @@ class JVPCloner::Implementation final
903
905
// / Tangent: copy_addr tan[x] to tan[y]
904
906
void visitCopyAddrInst (CopyAddrInst *cai) {
905
907
TypeSubstCloner::visitCopyAddrInst (cai);
906
- // If a non-active buffer is copied into an active buffer,
907
- // we have to zero-initialized the buffer.
908
- // If an active buffer is copied ( take) into a non-active buffer,
909
- // we have to uninitialize the buffer.
908
+ // If a non-active buffer is copied into an active buffer, zero-initialize
909
+ // the destination buffer's tangent buffer.
910
+ // If an active buffer is copied with take into a non-active buffer, destroy
911
+ // the source buffer's tangent buffer.
910
912
if (!differentialInfo.shouldDifferentiateInstruction (cai)) {
911
- // Zero-initialize `tanBufDest`.
912
913
if (activityInfo.isActive (cai->getDest (), getIndices ())) {
913
914
auto &tanBufDest = getTangentBuffer (cai->getParent (), cai->getDest ());
914
915
emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
915
916
tanBufDest.getLoc ());
916
917
}
917
- // Destroy `tanBufSrc`.
918
918
if (cai->isTakeOfSrc () &&
919
919
activityInfo.isActive (cai->getSrc (), getIndices ())) {
920
920
auto &tanBufSrc = getTangentBuffer (cai->getParent (), cai->getSrc ());
@@ -923,6 +923,7 @@ class JVPCloner::Implementation final
923
923
}
924
924
return ;
925
925
}
926
+ // Otherwise, do standard differential cloning.
926
927
auto diffBuilder = getDifferentialBuilder ();
927
928
auto loc = cai->getLoc ();
928
929
auto *bb = cai->getParent ();
@@ -1630,7 +1631,7 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1630
1631
auto paramIndex =
1631
1632
std::distance (origTy->getParameters ().begin (), &*inoutParamIt);
1632
1633
// If the original `inout` parameter is a differentiability parameter,
1633
- // then it already has a corresponding differential parameter. Skip adding
1634
+ // then it already has a corresponding differential parameter. Do not add
1634
1635
// a corresponding differential result.
1635
1636
if (indices.parameters ->contains (paramIndex))
1636
1637
continue ;
0 commit comments