@@ -134,19 +134,6 @@ class JVPCloner::Implementation final
134
134
// General utilities
135
135
// --------------------------------------------------------------------------//
136
136
137
// Diff note: every line of this helper carries the `-` marker, so the helper
// is deleted by this change.
// Purpose (before deletion): return the iterator at which the next
// differential local allocation is inserted.
- SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint () {
138
- // If no local allocations have been recorded yet, insert at the beginning
139
- // of the differential's entry block.
140
- if (differentialLocalAllocations.empty ())
141
- return getDifferential ().getEntryBlock ()->begin ();
142
- // Otherwise, insert before the last recorded local allocation, i.e. at the
143
- // position of its defining instruction. Inserting before rather than after
144
- // ensures that allocation and zero initialization instructions are grouped together.
145
- auto lastLocalAlloc = differentialLocalAllocations.back ();
146
- auto it = lastLocalAlloc->getDefiningInstruction ()->getIterator ();
147
- return it;
148
- }
149
-
150
137
// / Get the lowered SIL type of the given AST type.
151
138
SILType getLoweredType (Type type) {
152
139
auto jvpGenSig = jvp->getLoweredFunctionType ()->getSubstGenericSignature ();
@@ -309,6 +296,8 @@ class JVPCloner::Implementation final
309
296
// Tangent buffer mapping
310
297
// --------------------------------------------------------------------------//
311
298
299
+ // / Sets the tangent buffer for the original buffer. Asserts that the
300
+ // / original buffer does not already have a tangent buffer.
312
301
void setTangentBuffer (SILBasicBlock *origBB, SILValue originalBuffer,
313
302
SILValue tangentBuffer) {
314
303
assert (originalBuffer->getType ().isAddress ());
@@ -318,13 +307,14 @@ class JVPCloner::Implementation final
318
307
(void )insertion;
319
308
}
320
309
310
+ /// Returns a reference to the tangent buffer stored in `bufferMap` under the key {origBB, originalBuffer}.
311
+ /// Asserts that the original buffer has an address type, belongs to `original`, and already has a tangent buffer.
321
312
SILValue &getTangentBuffer (SILBasicBlock *origBB, SILValue originalBuffer) {
322
313
assert (originalBuffer->getType ().isAddress ());
323
314
assert (originalBuffer->getFunction () == original);
// Removed lookup: `try_emplace` with a placeholder `SILValue ()`; the assert on
// `insertion.second` checked that no new entry had been created.
324
- auto insertion =
325
- bufferMap.try_emplace ({origBB, originalBuffer}, SILValue ());
326
- assert (!insertion.second && " Tangent buffer should already exist" );
327
- return insertion.first ->getSecond ();
// Added lookup: `find` followed by an assert that the key is present, so the
// lookup itself never adds an entry.
// NOTE(review): presumably `bufferMap` is an llvm::DenseMap (`getSecond` is
// used on the iterator) -- confirm against the member declaration.
315
+ auto it = bufferMap.find ({origBB, originalBuffer});
316
+ assert (it != bufferMap.end () && " Tangent buffer should already exist" );
317
+ return it->getSecond ();
328
318
}
329
319
330
320
// --------------------------------------------------------------------------//
@@ -446,9 +436,21 @@ class JVPCloner::Implementation final
446
436
// If an `apply` has active results or active inout parameters, replace it
447
437
// with an `apply` of its JVP.
448
438
void visitApplyInst (ApplyInst *ai) {
439
+ bool shouldDifferentiate =
440
+ differentialInfo.shouldDifferentiateApplySite (ai);
441
+ // If the function has no active arguments or results, zero-initialize the
442
+ // tangent buffers of the active indirect results.
443
+ if (!shouldDifferentiate) {
444
+ for (auto indResult : ai->getIndirectSILResults ())
445
+ if (activityInfo.isActive (indResult, getIndices ())) {
446
+ auto &tanBuf = getTangentBuffer (ai->getParent (), indResult);
447
+ emitZeroIndirect (tanBuf->getType ().getASTType (), tanBuf,
448
+ tanBuf.getLoc ());
449
+ }
450
+ }
449
451
// If the function should not be differentiated or it is the array literal
450
452
// initialization intrinsic, just do standard cloning.
451
- if (!differentialInfo. shouldDifferentiateApplySite (ai) ||
453
+ if (!shouldDifferentiate ||
452
454
ArraySemanticsCall (ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
453
455
LLVM_DEBUG (getADDebugStream () << " No active results:\n " << *ai << ' \n ' );
454
456
TypeSubstCloner::visitApplyInst (ai);
@@ -789,7 +791,7 @@ class JVPCloner::Implementation final
789
791
auto &diffBuilder = getDifferentialBuilder ();
790
792
auto loc = dvi->getLoc ();
791
793
auto tanVal = materializeTangent (getTangentValue (dvi->getOperand ()), loc);
792
- diffBuilder.emitDestroyValue (loc, tanVal);
794
+ diffBuilder.emitDestroyValueOperation (loc, tanVal);
793
795
}
794
796
795
797
CLONE_AND_EMIT_TANGENT (CopyValue, cvi) {
@@ -804,7 +806,20 @@ class JVPCloner::Implementation final
804
806
// / Handle `load` instruction.
805
807
// / Original: y = load x
806
808
// / Tangent: tan[y] = load tan[x]
807
- CLONE_AND_EMIT_TANGENT (Load, li) {
809
+ void visitLoadInst (LoadInst *li) {
810
+ TypeSubstCloner::visitLoadInst (li);
811
+ // If an active buffer is loaded with take to a non-active value, destroy
812
+ // the active buffer's tangent buffer.
813
+ if (!differentialInfo.shouldDifferentiateInstruction (li)) {
814
+ auto isTake =
815
+ (li->getOwnershipQualifier () == LoadOwnershipQualifier::Take);
816
+ if (isTake && activityInfo.isActive (li->getOperand (), getIndices ())) {
817
+ auto &tanBuf = getTangentBuffer (li->getParent (), li->getOperand ());
818
+ getDifferentialBuilder ().emitDestroyOperation (tanBuf.getLoc (), tanBuf);
819
+ }
820
+ return ;
821
+ }
822
+ // Otherwise, do standard differential cloning.
808
823
auto &diffBuilder = getDifferentialBuilder ();
809
824
auto *bb = li->getParent ();
810
825
auto loc = li->getLoc ();
@@ -829,7 +844,19 @@ class JVPCloner::Implementation final
829
844
// / Handle `store` instruction in the differential.
830
845
// / Original: store x to y
831
846
// / Tangent: store tan[x] to tan[y]
832
- CLONE_AND_EMIT_TANGENT (Store, si) {
847
+ void visitStoreInst (StoreInst *si) {
848
+ TypeSubstCloner::visitStoreInst (si);
849
+ // If a non-active value is stored into an active buffer, zero-initialize
850
+ // the active buffer's tangent buffer.
851
+ if (!differentialInfo.shouldDifferentiateInstruction (si)) {
852
+ if (activityInfo.isActive (si->getDest (), getIndices ())) {
853
+ auto &tanBufDest = getTangentBuffer (si->getParent (), si->getDest ());
854
+ emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
855
+ tanBufDest.getLoc ());
856
+ }
857
+ return ;
858
+ }
859
+ // Otherwise, do standard differential cloning.
833
860
auto &diffBuilder = getDifferentialBuilder ();
834
861
auto loc = si->getLoc ();
835
862
auto tanValSrc = materializeTangent (getTangentValue (si->getSrc ()), loc);
@@ -841,7 +868,19 @@ class JVPCloner::Implementation final
841
868
// / Handle `store_borrow` instruction in the differential.
842
869
// / Original: store_borrow x to y
843
870
// / Tangent: store_borrow tan[x] to tan[y]
844
- CLONE_AND_EMIT_TANGENT (StoreBorrow, sbi) {
871
+ void visitStoreBorrowInst (StoreBorrowInst *sbi) {
872
+ TypeSubstCloner::visitStoreBorrowInst (sbi);
873
+ // If a non-active value is stored into an active buffer, zero-initialize
874
+ // the active buffer's tangent buffer.
875
+ if (!differentialInfo.shouldDifferentiateInstruction (sbi)) {
876
+ if (activityInfo.isActive (sbi->getDest (), getIndices ())) {
877
+ auto &tanBufDest = getTangentBuffer (sbi->getParent (), sbi->getDest ());
878
+ emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
879
+ tanBufDest.getLoc ());
880
+ }
881
+ return ;
882
+ }
883
+ // Otherwise, do standard differential cloning.
845
884
auto &diffBuilder = getDifferentialBuilder ();
846
885
auto loc = sbi->getLoc ();
847
886
auto tanValSrc = materializeTangent (getTangentValue (sbi->getSrc ()), loc);
@@ -852,13 +891,32 @@ class JVPCloner::Implementation final
852
891
/// Handle `copy_addr` instruction.
853
892
/// Original: copy_addr x to y
854
893
/// Tangent: copy_addr tan[x] to tan[y]
/// The original instruction is always cloned through
/// `TypeSubstCloner::visitCopyAddrInst`. When the instruction should not be
/// differentiated, only tangent-buffer bookkeeping is emitted and no tangent
/// `copy_addr` is created.
855
- CLONE_AND_EMIT_TANGENT (CopyAddr, cai) {
894
+ void visitCopyAddrInst (CopyAddrInst *cai) {
895
+ TypeSubstCloner::visitCopyAddrInst (cai);
896
+ // Non-differentiated case, part 1: if the destination buffer is active,
897
+ // zero-initialize the destination buffer's tangent buffer.
898
+ // Non-differentiated case, part 2: if the copy takes from an active source
899
+ // buffer, destroy the source buffer's tangent buffer.
900
+ if (!differentialInfo.shouldDifferentiateInstruction (cai)) {
901
+ if (activityInfo.isActive (cai->getDest (), getIndices ())) {
902
+ auto &tanBufDest = getTangentBuffer (cai->getParent (), cai->getDest ());
903
+ emitZeroIndirect (tanBufDest->getType ().getASTType (), tanBufDest,
904
+ tanBufDest.getLoc ());
905
+ }
906
+ if (cai->isTakeOfSrc () &&
907
+ activityInfo.isActive (cai->getSrc (), getIndices ())) {
908
+ auto &tanBufSrc = getTangentBuffer (cai->getParent (), cai->getSrc ());
909
+ getDifferentialBuilder ().emitDestroyOperation (tanBufSrc.getLoc (),
910
+ tanBufSrc);
911
+ }
912
+ return ;
913
+ }
914
+ // Otherwise, do standard differential cloning: copy tan[src] to tan[dest] with the original take/initialization flags.
// NOTE(review): `diffBuilder` is bound by value here, whereas the load, store
// and store_borrow visitors bind `auto &diffBuilder` -- confirm the copy of
// the builder is intended.
856
915
auto diffBuilder = getDifferentialBuilder ();
857
916
auto loc = cai->getLoc ();
858
917
auto *bb = cai->getParent ();
859
918
auto &tanSrc = getTangentBuffer (bb, cai->getSrc ());
860
919
auto tanDest = getTangentBuffer (bb, cai->getDest ());
861
-
862
920
diffBuilder.createCopyAddr (loc, tanSrc, tanDest, cai->isTakeOfSrc (),
863
921
cai->isInitializationOfDest ());
864
922
}
@@ -918,8 +976,8 @@ class JVPCloner::Implementation final
918
976
auto &diffBuilder = getDifferentialBuilder ();
919
977
auto *bb = eai->getParent ();
920
978
auto loc = eai->getLoc ();
921
- auto tanSrc = getTangentBuffer (bb, eai->getOperand ());
922
- diffBuilder.createEndAccess (loc, tanSrc , eai->isAborting ());
979
+ auto tanOperand = getTangentBuffer (bb, eai->getOperand ());
980
+ diffBuilder.createEndAccess (loc, tanOperand , eai->isAborting ());
923
981
}
924
982
925
983
// / Handle `alloc_stack` instruction.
@@ -930,7 +988,7 @@ class JVPCloner::Implementation final
930
988
auto *mappedAllocStackInst = diffBuilder.createAllocStack (
931
989
asi->getLoc (), getRemappedTangentType (asi->getElementType ()),
932
990
asi->getVarInfo ());
933
- bufferMap. try_emplace ({ asi->getParent (), asi} , mappedAllocStackInst);
991
+ setTangentBuffer ( asi->getParent (), asi, mappedAllocStackInst);
934
992
}
935
993
936
994
// / Handle `dealloc_stack` instruction.
@@ -1062,16 +1120,15 @@ class JVPCloner::Implementation final
1062
1120
auto tanType = getRemappedTangentType (tei->getType ());
1063
1121
auto tanSource =
1064
1122
materializeTangent (getTangentValue (tei->getOperand ()), loc);
1065
- SILValue tanBuf;
1066
- // If the tangent buffer of the source does not have a tuple type, then
1123
+ // If the tangent value of the source does not have a tuple type, then
1067
1124
// it must represent a "single element tuple type". Use it directly.
1068
1125
if (!tanSource->getType ().is <TupleType>()) {
1069
1126
setTangentValue (tei->getParent (), tei,
1070
1127
makeConcreteTangentValue (tanSource));
1071
1128
} else {
1072
- tanBuf =
1129
+ auto tanElt =
1073
1130
diffBuilder.createTupleExtract (loc, tanSource, tanIndex, tanType);
1074
- bufferMap. try_emplace ({ tei->getParent (), tei}, tanBuf );
1131
+ setTangentValue ( tei->getParent (), tei, makeConcreteTangentValue (tanElt) );
1075
1132
}
1076
1133
}
1077
1134
@@ -1100,7 +1157,7 @@ class JVPCloner::Implementation final
1100
1157
tanBuf = diffBuilder.createTupleElementAddr (teai->getLoc (), tanSource,
1101
1158
tanIndex, tanType);
1102
1159
}
1103
- bufferMap. try_emplace ({ teai->getParent (), teai} , tanBuf);
1160
+ setTangentBuffer ( teai->getParent (), teai, tanBuf);
1104
1161
}
1105
1162
1106
1163
// / Handle `destructure_tuple` instruction.
@@ -1282,9 +1339,8 @@ class JVPCloner::Implementation final
1282
1339
// Collect original results.
1283
1340
SmallVector<SILValue, 2 > originalResults;
1284
1341
collectAllDirectResultsInTypeOrder (*original, originalResults);
1285
- // Collect differential return elements .
1342
+ // Collect differential direct results .
1286
1343
SmallVector<SILValue, 8 > retElts;
1287
- // for (auto origResult : originalResults) {
1288
1344
for (auto i : range (originalResults.size ())) {
1289
1345
auto origResult = originalResults[i];
1290
1346
if (!getIndices ().results ->contains (i))
@@ -1401,7 +1457,10 @@ JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
1401
1457
void JVPCloner::Implementation::prepareForDifferentialGeneration () {
1402
1458
// Create differential blocks and arguments.
1403
1459
auto &differential = getDifferential ();
1460
+ auto diffLoc = differential.getLocation ();
1404
1461
auto *origEntry = original->getEntryBlock ();
1462
+ auto origFnTy = original->getLoweredFunctionType ();
1463
+
1405
1464
for (auto &origBB : *original) {
1406
1465
auto *diffBB = differential.createBasicBlock ();
1407
1466
diffBBMap.insert ({&origBB, diffBB});
@@ -1482,21 +1541,51 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1482
1541
<< " as the tangent of original result " << *origArg);
1483
1542
}
1484
1543
1485
- // Initialize tangent mapping for indirect results.
1486
- auto origIndResults = original->getIndirectResults ();
1544
+ // Initialize tangent mapping for original indirect results and non-wrt
1545
+ // `inout` parameters. The tangent buffers of these address values are
1546
+ // differential indirect results.
1547
+
1548
+ // Collect original results.
1549
+ SmallVector<SILValue, 2 > originalResults;
1550
+ collectAllFormalResultsInTypeOrder (*original, originalResults);
1551
+
1552
+ // Iterate over differentiability results.
1553
+ differentialBuilder.setInsertionPoint (differential.getEntryBlock ());
1487
1554
auto diffIndResults = differential.getIndirectResults ();
1488
- #ifndef NDEBUG
1489
- unsigned numNonWrtInoutParameters = llvm::count_if (
1490
- range (original->getLoweredFunctionType ()->getNumParameters ()),
1491
- [&] (unsigned i) {
1492
- auto ¶mInfo = original->getLoweredFunctionType ()->getParameters ()[i];
1493
- return paramInfo.isIndirectInOut () && !getIndices ().parameters ->contains (i);
1494
- });
1495
- #endif
1496
- assert (origIndResults.size () + numNonWrtInoutParameters == diffIndResults.size ());
1497
- for (auto &origBB : *original)
1498
- for (auto i : indices (origIndResults))
1499
- setTangentBuffer (&origBB, origIndResults[i], diffIndResults[i]);
1555
+ unsigned differentialIndirectResultIndex = 0 ;
1556
+ for (auto resultIndex : getIndices ().results ->getIndices ()) {
1557
+ auto origResult = originalResults[resultIndex];
1558
+ // Handle original formal indirect result.
1559
+ if (resultIndex < origFnTy->getNumResults ()) {
1560
+ // Skip original direct results.
1561
+ if (origResult->getType ().isObject ())
1562
+ continue ;
1563
+ auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1564
+ setTangentBuffer (origEntry, origResult, diffIndResult);
1565
+ // If original indirect result is non-varied, zero-initialize its tangent
1566
+ // buffer.
1567
+ if (!activityInfo.isVaried (origResult, getIndices ().parameters ))
1568
+ emitZeroIndirect (diffIndResult->getType ().getASTType (),
1569
+ diffIndResult, diffLoc);
1570
+ continue ;
1571
+ }
1572
+ // Handle original non-wrt `inout` parameter.
1573
+ // Only original *non-wrt* `inout` parameters have corresponding
1574
+ // differential indirect results.
1575
+ auto inoutParamIndex = resultIndex - origFnTy->getNumResults ();
1576
+ auto inoutParamIt = std::next (
1577
+ origFnTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
1578
+ auto paramIndex =
1579
+ std::distance (origFnTy->getParameters ().begin (), &*inoutParamIt);
1580
+ if (getIndices ().parameters ->contains (paramIndex))
1581
+ continue ;
1582
+ auto diffIndResult = diffIndResults[differentialIndirectResultIndex++];
1583
+ setTangentBuffer (origEntry, origResult, diffIndResult);
1584
+ // Original `inout` parameters are initialized, so their tangent buffers
1585
+ // must also be initialized.
1586
+ emitZeroIndirect (diffIndResult->getType ().getASTType (),
1587
+ diffIndResult, diffLoc);
1588
+ }
1500
1589
}
1501
1590
1502
1591
/* static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential (
@@ -1526,7 +1615,6 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1526
1615
auto origParams = origTy->getParameters ();
1527
1616
auto indices = witness->getSILAutoDiffIndices ();
1528
1617
1529
-
1530
1618
for (auto resultIndex : indices.results ->getIndices ()) {
1531
1619
if (resultIndex < origTy->getNumResults ()) {
1532
1620
// Handle formal original result.
@@ -1539,17 +1627,16 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
1539
1627
->getType ()
1540
1628
->getCanonicalType (witnessCanGenSig),
1541
1629
origResult.getConvention ()));
1542
- }
1543
- else {
1630
+ } else {
1544
1631
// Handle original `inout` parameter.
1545
1632
auto inoutParamIndex = resultIndex - origTy->getNumResults ();
1546
1633
auto inoutParamIt = std::next (
1547
1634
origTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
1548
1635
auto paramIndex =
1549
1636
std::distance (origTy->getParameters ().begin (), &*inoutParamIt);
1550
- // If the original `inout` parameter is a differentiability parameter, then
1551
- // it already has a corresponding differential parameter. Skip adding a
1552
- // corresponding differential result.
1637
+ // If the original `inout` parameter is a differentiability parameter,
1638
+ // then it already has a corresponding differential parameter. Do not add
1639
+ // a corresponding differential result.
1553
1640
if (indices.parameters ->contains (paramIndex))
1554
1641
continue ;
1555
1642
auto inoutParam = origTy->getParameters ()[paramIndex];
0 commit comments