@@ -98,7 +98,7 @@ static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
98
98
return nullptr ;
99
99
}
100
100
101
- // / Given a value, find its single `destructure_tuple` user if the value is
101
+ // / Given a value, finds its single `destructure_tuple` user if the value is
102
102
// / tuple-typed and such a user exists.
103
103
static DestructureTupleInst *getSingleDestructureTupleUser (SILValue value) {
104
104
bool foundDestructureTupleUser = false ;
@@ -130,7 +130,7 @@ static void forEachApplyDirectResult(
130
130
resultCallback (result);
131
131
}
132
132
133
- // / Given a function, gather all of its formal results (both direct and
133
+ // / Given a function, gathers all of its formal results (both direct and
134
134
// / indirect) in an order defined by its result type. Note that "formal results"
135
135
// / refer to result values in the body of the function, not at call sites.
136
136
static void
@@ -154,7 +154,7 @@ collectAllFormalResultsInTypeOrder(SILFunction &function,
154
154
: indResults[indResIdx++]);
155
155
}
156
156
157
- // / Given a function, gather all of its direct results in an order defined by
157
+ // / Given a function, gathers all of its direct results in an order defined by
158
158
// / its result type. Note that "formal results" refer to result values in the
159
159
// / body of the function, not at call sites.
160
160
static void
@@ -171,7 +171,7 @@ collectAllDirectResultsInTypeOrder(SILFunction &function,
171
171
results.push_back (retVal);
172
172
}
173
173
174
- // / Given a function call site, gather all of its actual results (both direct
174
+ // / Given a function call site, gathers all of its actual results (both direct
175
175
// / and indirect) in an order defined by its result type.
176
176
static void collectAllActualResultsInTypeOrder (
177
177
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
@@ -291,7 +291,7 @@ static GenericParamList *cloneGenericParameters(ASTContext &ctx,
291
291
return GenericParamList::create (ctx, SourceLoc (), clonedParams, SourceLoc ());
292
292
}
293
293
294
- // / Given an `differentiable_function` instruction, find the corresponding
294
+ // / Given a `differentiable_function` instruction, finds the corresponding
295
295
// / differential operator used in the AST. If no differential operator is found,
296
296
// / returns nullptr.
297
297
static DifferentiableFunctionExpr *
@@ -412,15 +412,23 @@ struct DifferentiationInvoker {
412
412
413
413
class DifferentiableActivityInfo ;
414
414
415
- // / Information about the JVP/VJP function produced during JVP/VJP generation,
416
- // / e.g. mappings from original values to corresponding values in the
417
- // / pullback/differential struct.
415
+ // / Linear map struct and branching trace enum information for an original
416
+ // / function and derivative function (JVP or VJP).
418
417
// /
419
- // / A linear map struct is an aggregate value containing linear maps checkpointed
420
- // / during the JVP/VJP computation. Linear map structs are generated for every
421
- // / original function during JVP/VJP generation. Linear map struct values are
422
- // / constructed by JVP/VJP functions and consumed by pullback/differential
423
- // / functions.
418
+ // / Linear map structs contain all callee linear maps produced in a JVP/VJP
419
+ // / basic block. A linear map struct is created for each basic block in the
420
+ // / original function, and a linear map struct field is created for every active
421
+ // / `apply` in the original basic block.
422
+ // /
423
+ // / Branching trace enums model the control flow graph of the original function.
424
+ // / A branching trace enum is created for each basic block in the original
425
+ // / function, and a branching trace enum case is created for every basic block
426
+ // / predecessor/successor. This supports control flow differentiation: JVP/VJP
427
+ // / functions build branching trace enums to record an execution trace. Indirect
428
+ // / branching trace enums are created for basic blocks that are in loops.
429
+ // /
430
+ // / Linear map struct values and branching trace enum values are constructed in
431
+ // / JVP/VJP functions and consumed in pullback/differential functions.
424
432
class LinearMapInfo {
425
433
private:
426
434
// / The linear map kind.
@@ -446,13 +454,12 @@ class LinearMapInfo {
446
454
// / For differentials: these are successor enums.
447
455
DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
448
456
449
- // / Mapping from `apply` and `struct_extract` instructions in the original
450
- // / function to the corresponding linear map declaration in the linear map
451
- // / struct.
452
- DenseMap<SILInstruction *, VarDecl *> linearMapValueMap;
457
+ // / Mapping from `apply` instructions in the original function to the
458
+ // / corresponding linear map field declaration in the linear map struct.
459
+ DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
453
460
454
- // / Mapping from predecessor+ succcessor basic block pairs in original function
455
- // / to the corresponding branching trace enum case.
461
+ // / Mapping from predecessor-successor basic block pairs in the original
462
+ // / function to the corresponding branching trace enum case.
456
463
DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
457
464
branchingTraceEnumCases;
458
465
@@ -505,7 +512,7 @@ class LinearMapInfo {
505
512
llvm_unreachable (" No files?" );
506
513
}
507
514
508
- // / Compute and set the access level for the given nominal type, given the
515
+ // / Computes and sets the access level for the given nominal type, given the
509
516
// / original function linkage.
510
517
void computeAccessLevel (
511
518
NominalTypeDecl *nominal, SILLinkage originalLinkage) {
@@ -661,8 +668,8 @@ class LinearMapInfo {
661
668
return linearMapStruct;
662
669
}
663
670
664
- // / Add a linear map to the linear map struct.
665
- VarDecl *addLinearMapDecl (SILInstruction *inst , SILType linearMapType) {
671
+ // / Adds a linear map field to the linear map struct.
672
+ VarDecl *addLinearMapDecl (ApplyInst *ai , SILType linearMapType) {
666
673
// IRGen requires decls to have AST types (not `SILFunctionType`), so we
667
674
// convert the `SILFunctionType` of the linear map to a `FunctionType` with
668
675
// the same parameters and results.
@@ -678,28 +685,28 @@ class LinearMapInfo {
678
685
astFnTy = FunctionType::get (
679
686
params, silFnTy->getAllResultsInterfaceType ().getASTType ());
680
687
681
- auto *origBB = inst ->getParent ();
688
+ auto *origBB = ai ->getParent ();
682
689
auto *linMapStruct = getLinearMapStruct (origBB);
683
690
std::string linearMapName;
684
691
switch (kind) {
685
692
case AutoDiffLinearMapKind::Differential:
686
- linearMapName = " differential_" + llvm::itostr (linearMapValueMap .size ());
693
+ linearMapName = " differential_" + llvm::itostr (linearMapFieldMap .size ());
687
694
break ;
688
695
case AutoDiffLinearMapKind::Pullback:
689
- linearMapName = " pullback_" + llvm::itostr (linearMapValueMap .size ());
696
+ linearMapName = " pullback_" + llvm::itostr (linearMapFieldMap .size ());
690
697
break ;
691
698
}
692
699
auto *linearMapDecl = addVarDecl (linMapStruct, linearMapName, astFnTy);
693
- linearMapValueMap .insert ({inst , linearMapDecl});
700
+ linearMapFieldMap .insert ({ai , linearMapDecl});
694
701
return linearMapDecl;
695
702
}
696
703
697
- // / Given an `apply` instruction, conditionally adds its linear map function
698
- // / to the linear map struct if it is active.
704
+ // / Given an `apply` instruction, conditionally adds a linear map struct field
705
+ // / for its linear map function if it is active.
699
706
void addLinearMapToStruct (ADContext &context, ApplyInst *ai,
700
707
const SILAutoDiffIndices &indices);
701
708
702
- // / Generate linear map struct and branching enum declarations for the given
709
+ // / Generates linear map struct and branching enum declarations for the given
703
710
// / function. Linear map structs are populated with linear map fields and a
704
711
// / branching enum field.
705
712
void generateDifferentiationDataStructures (
@@ -771,12 +778,13 @@ class LinearMapInfo {
771
778
return linearMapStructEnumFields.lookup (linearMapStruct);
772
779
}
773
780
774
- // / Finds the linear map declaration in the pullback struct for an `apply` or
775
- // / `struct_extract` in the original function.
776
- VarDecl *lookUpLinearMapDecl (SILInstruction *inst) {
777
- auto lookup = linearMapValueMap.find (inst);
778
- assert (lookup != linearMapValueMap.end () &&
779
- " No linear map declaration corresponding to the given instruction" );
781
+ // / Finds the linear map field declaration in the linear map struct for the
782
+ // / given `apply` instruction in the original function.
783
+ VarDecl *lookUpLinearMapDecl (ApplyInst *ai) {
784
+ assert (ai->getFunction () == original);
785
+ auto lookup = linearMapFieldMap.find (ai);
786
+ assert (lookup != linearMapFieldMap.end () &&
787
+ " No linear map field corresponding to the given `apply`" );
780
788
return lookup->getSecond ();
781
789
}
782
790
};
@@ -1047,7 +1055,8 @@ class ADContext {
1047
1055
for (auto *da : original->getAttrs ().getAttributes <DifferentiableAttr>()) {
1048
1056
auto *daParamIndices = da->getParameterIndices ();
1049
1057
auto *daIndexSet = autodiff::getLoweredParameterIndices (
1050
- daParamIndices, original->getInterfaceType ()->castTo <AnyFunctionType>());
1058
+ daParamIndices,
1059
+ original->getInterfaceType ()->castTo <AnyFunctionType>());
1051
1060
// If all indices in `indexSet` are in `daIndexSet`, and it has fewer
1052
1061
// indices than our current candidate and a primitive VJP, then `da` is
1053
1062
// our new candidate.
@@ -1646,8 +1655,6 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
1646
1655
return false ;
1647
1656
}
1648
1657
1649
- // / Takes an `apply` instruction and adds its linear map function to the
1650
- // / linear map struct if it is active.
1651
1658
void LinearMapInfo::addLinearMapToStruct (ADContext &context, ApplyInst *ai,
1652
1659
const SILAutoDiffIndices &indices) {
1653
1660
SmallVector<SILValue, 4 > allResults;
@@ -1696,15 +1703,15 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
1696
1703
[&](CanSILFunctionType origFnTy) {
1697
1704
// Check non-differentiable arguments.
1698
1705
for (unsigned paramIndex : range (origFnTy->getNumParameters ())) {
1699
- auto remappedParamType =
1700
- origFnTy-> getParameters ()[paramIndex] .getSILStorageInterfaceType ();
1706
+ auto remappedParamType = origFnTy-> getParameters ()[paramIndex]
1707
+ .getSILStorageInterfaceType ();
1701
1708
if (applyIndices.isWrtParameter (paramIndex) &&
1702
1709
!remappedParamType.isDifferentiable (derivative->getModule ()))
1703
1710
return true ;
1704
1711
}
1705
1712
// Check non-differentiable results.
1706
- auto remappedResultType =
1707
- origFnTy-> getResults ()[applyIndices. source ] .getSILStorageInterfaceType ();
1713
+ auto remappedResultType = origFnTy-> getResults ()[applyIndices. source ]
1714
+ .getSILStorageInterfaceType ();
1708
1715
if (!remappedResultType.isDifferentiable (derivative->getModule ()))
1709
1716
return true ;
1710
1717
return false ;
@@ -1713,13 +1720,13 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
1713
1720
return ;
1714
1721
1715
1722
AutoDiffDerivativeFunctionKind derivativeFnKind (kind);
1716
- auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType (
1717
- parameters, source, derivativeFnKind, context.getTypeConverter (),
1718
- LookUpConformanceInModule (derivative->getModule ().getSwiftModule ()));
1723
+ auto derivativeFnType =
1724
+ remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType (
1725
+ parameters, source, derivativeFnKind, context.getTypeConverter (),
1726
+ LookUpConformanceInModule (derivative->getModule ().getSwiftModule ()));
1719
1727
1720
1728
auto derivativeFnResultTypes =
1721
1729
derivativeFnType->getAllResultsInterfaceType ().castTo <TupleType>();
1722
- derivativeFnResultTypes->getElement (derivativeFnResultTypes->getElements ().size () - 1 );
1723
1730
auto linearMapSILType = SILType::getPrimitiveObjectType (
1724
1731
derivativeFnResultTypes
1725
1732
->getElement (derivativeFnResultTypes->getElements ().size () - 1 )
@@ -1760,8 +1767,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
1760
1767
break ;
1761
1768
}
1762
1769
for (auto &origBB : *original) {
1763
- auto *traceEnum =
1764
- createBranchingTraceDecl ( &origBB, indices, derivativeFnGenSig, loopInfo);
1770
+ auto *traceEnum = createBranchingTraceDecl (
1771
+ &origBB, indices, derivativeFnGenSig, loopInfo);
1765
1772
branchingTraceDecls.insert ({&origBB, traceEnum});
1766
1773
if (origBB.isEntry ())
1767
1774
continue ;
@@ -5847,6 +5854,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5847
5854
// / Record a temporary value for cleanup before its block's terminator.
5848
5855
SILValue recordTemporary (SILValue value) {
5849
5856
assert (value->getType ().isObject ());
5857
+ assert (value->getFunction () == &getPullback ());
5850
5858
blockTemporaries[value->getParentBlock ()].push_back (value);
5851
5859
LLVM_DEBUG (getADDebugStream () << " Recorded temporary " << value);
5852
5860
auto insertion = blockTemporarySet.insert (value); (void )insertion;
@@ -5971,7 +5979,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5971
5979
auto insertion = valueMap.try_emplace ({origBB, originalValue},
5972
5980
adjointValue);
5973
5981
LLVM_DEBUG (getADDebugStream ()
5974
- << " The existing adjoint value will be replaced : "
5982
+ << " The new adjoint value, replacing the existing one, is : "
5975
5983
<< insertion.first ->getSecond ());
5976
5984
if (!insertion.second )
5977
5985
insertion.first ->getSecond () = adjointValue;
@@ -6000,7 +6008,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6000
6008
assert (originalValue->getType ().isObject ());
6001
6009
assert (newAdjointValue.getType ().isObject ());
6002
6010
assert (originalValue->getFunction () == &getOriginal ());
6003
- LLVM_DEBUG (getADDebugStream () << " Adding adjoint for " << originalValue);
6011
+ LLVM_DEBUG (getADDebugStream () << " Adding adjoint value for "
6012
+ << originalValue);
6004
6013
// The adjoint value must be in the tangent space.
6005
6014
assert (newAdjointValue.getType () ==
6006
6015
getRemappedTangentType (originalValue->getType ()));
@@ -6494,8 +6503,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6494
6503
SmallVector<SILValue, 4 > directResults;
6495
6504
auto indirectResultIt = pullback.getIndirectResults ().begin ();
6496
6505
for (auto resultInfo : pullback.getLoweredFunctionType ()->getResults ()) {
6497
- auto resultType =
6498
- pullback. mapTypeIntoContext ( resultInfo.getInterfaceType ())->getCanonicalType ();
6506
+ auto resultType = pullback. mapTypeIntoContext (
6507
+ resultInfo.getInterfaceType ())->getCanonicalType ();
6499
6508
if (resultInfo.isFormalDirect ())
6500
6509
directResults.push_back (emitZeroDirect (resultType, pbLoc));
6501
6510
else
@@ -6539,7 +6548,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6539
6548
auto &predBBActiveValues = activeValues[origPredBB];
6540
6549
for (auto activeValue : predBBActiveValues) {
6541
6550
LLVM_DEBUG (getADDebugStream ()
6542
- << " Propagating active adjoint " << activeValue
6551
+ << " Propagating adjoint of active value " << activeValue
6543
6552
<< " to predecessors' pullback blocks\n " );
6544
6553
if (activeValue->getType ().isObject ()) {
6545
6554
auto activeValueAdj = getAdjointValue (origBB, activeValue);
@@ -7641,7 +7650,7 @@ SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs,
7641
7650
// TODO: Optimize for the case when lhs == rhs.
7642
7651
LLVM_DEBUG (getADDebugStream () <<
7643
7652
" Emitting adjoint accumulation for lhs: " << lhs <<
7644
- " and rhs: " << rhs << " \n " );
7653
+ " and rhs: " << rhs);
7645
7654
assert (lhs->getType () == rhs->getType () && " Adjoints must have equal types!" );
7646
7655
assert (lhs->getType ().isObject () && rhs->getType ().isObject () &&
7647
7656
" Adjoint types must be both object types!" );
0 commit comments