Skip to content

Commit c83e5aa

Browse files
committed
[AutoDiff] NFC: gardening.
Improve documentation comments and differentiation debug output. Add correctness assertions. Wrap lines to 80 columns. Change the `DenseMap` key type for `LinearMapInfo::linearMapFieldMap` from `SILInstruction *` to `ApplyInst *`, since stored properties (`struct_extract` instructions) cannot have custom derivatives.
1 parent 3b934c5 commit c83e5aa

File tree

1 file changed

+55
-46
lines changed

1 file changed

+55
-46
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -412,15 +412,23 @@ struct DifferentiationInvoker {
412412

413413
class DifferentiableActivityInfo;
414414

415-
/// Information about the JVP/VJP function produced during JVP/VJP generation,
416-
/// e.g. mappings from original values to corresponding values in the
417-
/// pullback/differential struct.
415+
/// Linear map struct and branching trace enum information for an original
416+
/// function and and derivative function (JVP or VJP).
418417
///
419-
/// A linear map struct is an aggregate value containing linear maps checkpointed
420-
/// during the JVP/VJP computation. Linear map structs are generated for every
421-
/// original function during JVP/VJP generation. Linear map struct values are
422-
/// constructed by JVP/VJP functions and consumed by pullback/differential
423-
/// functions.
418+
/// Linear map structs contain all callee linear maps produced in a JVP/VJP
419+
/// basic block. A linear map struct is created for each basic block in the
420+
/// original function, and a linear map struct field is created for every active
421+
/// `apply` in the original basic block.
422+
///
423+
/// Branching trace enums model the control flow graph of the original function.
424+
/// A branching trace enum is created for each basic block in the original
425+
/// function, and a branching trace enum case is created for every basic block
426+
/// predecessor/successor. This supports control flow differentiation: JVP/VJP
427+
/// functions build branching trace enums to record an execution trace. Indirect
428+
/// branching trace enums are created for basic blocks that are in loops.
429+
///
430+
/// Linear map struct values and branching trace enum values are constructed in
431+
/// JVP/VJP functions and consumed in pullback/differential functions.
424432
class LinearMapInfo {
425433
private:
426434
/// The linear map kind.
@@ -446,13 +454,12 @@ class LinearMapInfo {
446454
/// For differentials: these are successor enums.
447455
DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
448456

449-
/// Mapping from `apply` and `struct_extract` instructions in the original
450-
/// function to the corresponding linear map declaration in the linear map
451-
/// struct.
452-
DenseMap<SILInstruction *, VarDecl *> linearMapValueMap;
457+
/// Mapping from `apply` instructions in the original function to the
458+
/// corresponding linear map field declaration in the linear map struct.
459+
DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
453460

454-
/// Mapping from predecessor+succcessor basic block pairs in original function
455-
/// to the corresponding branching trace enum case.
461+
/// Mapping from predecessor-succcessor basic block pairs in the original
462+
/// function to the corresponding branching trace enum case.
456463
DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
457464
branchingTraceEnumCases;
458465

@@ -662,7 +669,7 @@ class LinearMapInfo {
662669
}
663670

664671
/// Add a linear map to the linear map struct.
665-
VarDecl *addLinearMapDecl(SILInstruction *inst, SILType linearMapType) {
672+
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
666673
// IRGen requires decls to have AST types (not `SILFunctionType`), so we
667674
// convert the `SILFunctionType` of the linear map to a `FunctionType` with
668675
// the same parameters and results.
@@ -678,24 +685,24 @@ class LinearMapInfo {
678685
astFnTy = FunctionType::get(
679686
params, silFnTy->getAllResultsInterfaceType().getASTType());
680687

681-
auto *origBB = inst->getParent();
688+
auto *origBB = ai->getParent();
682689
auto *linMapStruct = getLinearMapStruct(origBB);
683690
std::string linearMapName;
684691
switch (kind) {
685692
case AutoDiffLinearMapKind::Differential:
686-
linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size());
693+
linearMapName = "differential_" + llvm::itostr(linearMapFieldMap.size());
687694
break;
688695
case AutoDiffLinearMapKind::Pullback:
689-
linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size());
696+
linearMapName = "pullback_" + llvm::itostr(linearMapFieldMap.size());
690697
break;
691698
}
692699
auto *linearMapDecl = addVarDecl(linMapStruct, linearMapName, astFnTy);
693-
linearMapValueMap.insert({inst, linearMapDecl});
700+
linearMapFieldMap.insert({ai, linearMapDecl});
694701
return linearMapDecl;
695702
}
696703

697-
/// Given an `apply` instruction, conditionally adds its linear map function
698-
/// to the linear map struct if it is active.
704+
/// Given an `apply` instruction, conditionally add a linear map struct field
705+
/// for its linear map function if it is active.
699706
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
700707
const SILAutoDiffIndices &indices);
701708

@@ -771,12 +778,13 @@ class LinearMapInfo {
771778
return linearMapStructEnumFields.lookup(linearMapStruct);
772779
}
773780

774-
/// Finds the linear map declaration in the pullback struct for an `apply` or
775-
/// `struct_extract` in the original function.
776-
VarDecl *lookUpLinearMapDecl(SILInstruction *inst) {
777-
auto lookup = linearMapValueMap.find(inst);
778-
assert(lookup != linearMapValueMap.end() &&
779-
"No linear map declaration corresponding to the given instruction");
781+
/// Finds the linear map declaration in the pullback struct for the given
782+
/// `apply` instruction in the original function.
783+
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
784+
assert(ai->getFunction() == original);
785+
auto lookup = linearMapFieldMap.find(ai);
786+
assert(lookup != linearMapFieldMap.end() &&
787+
"No linear map field corresponding to the given `apply`");
780788
return lookup->getSecond();
781789
}
782790
};
@@ -1047,7 +1055,8 @@ class ADContext {
10471055
for (auto *da : original->getAttrs().getAttributes<DifferentiableAttr>()) {
10481056
auto *daParamIndices = da->getParameterIndices();
10491057
auto *daIndexSet = autodiff::getLoweredParameterIndices(
1050-
daParamIndices, original->getInterfaceType()->castTo<AnyFunctionType>());
1058+
daParamIndices,
1059+
original->getInterfaceType()->castTo<AnyFunctionType>());
10511060
// If all indices in `indexSet` are in `daIndexSet`, and it has fewer
10521061
// indices than our current candidate and a primitive VJP, then `da` is
10531062
// our new candidate.
@@ -1646,8 +1655,6 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
16461655
return false;
16471656
}
16481657

1649-
/// Takes an `apply` instruction and adds its linear map function to the
1650-
/// linear map struct if it is active.
16511658
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16521659
const SILAutoDiffIndices &indices) {
16531660
SmallVector<SILValue, 4> allResults;
@@ -1696,15 +1703,15 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16961703
[&](CanSILFunctionType origFnTy) {
16971704
// Check non-differentiable arguments.
16981705
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
1699-
auto remappedParamType =
1700-
origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType();
1706+
auto remappedParamType = origFnTy->getParameters()[paramIndex]
1707+
.getSILStorageInterfaceType();
17011708
if (applyIndices.isWrtParameter(paramIndex) &&
17021709
!remappedParamType.isDifferentiable(derivative->getModule()))
17031710
return true;
17041711
}
17051712
// Check non-differentiable results.
1706-
auto remappedResultType =
1707-
origFnTy->getResults()[applyIndices.source].getSILStorageInterfaceType();
1713+
auto remappedResultType = origFnTy->getResults()[applyIndices.source]
1714+
.getSILStorageInterfaceType();
17081715
if (!remappedResultType.isDifferentiable(derivative->getModule()))
17091716
return true;
17101717
return false;
@@ -1713,13 +1720,13 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
17131720
return;
17141721

17151722
AutoDiffDerivativeFunctionKind derivativeFnKind(kind);
1716-
auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
1717-
parameters, source, derivativeFnKind, context.getTypeConverter(),
1718-
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
1723+
auto derivativeFnType =
1724+
remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
1725+
parameters, source, derivativeFnKind, context.getTypeConverter(),
1726+
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
17191727

17201728
auto derivativeFnResultTypes =
17211729
derivativeFnType->getAllResultsInterfaceType().castTo<TupleType>();
1722-
derivativeFnResultTypes->getElement(derivativeFnResultTypes->getElements().size() - 1);
17231730
auto linearMapSILType = SILType::getPrimitiveObjectType(
17241731
derivativeFnResultTypes
17251732
->getElement(derivativeFnResultTypes->getElements().size() - 1)
@@ -1760,8 +1767,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
17601767
break;
17611768
}
17621769
for (auto &origBB : *original) {
1763-
auto *traceEnum =
1764-
createBranchingTraceDecl(&origBB, indices, derivativeFnGenSig, loopInfo);
1770+
auto *traceEnum = createBranchingTraceDecl(
1771+
&origBB, indices, derivativeFnGenSig, loopInfo);
17651772
branchingTraceDecls.insert({&origBB, traceEnum});
17661773
if (origBB.isEntry())
17671774
continue;
@@ -5847,6 +5854,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
58475854
/// Record a temporary value for cleanup before its block's terminator.
58485855
SILValue recordTemporary(SILValue value) {
58495856
assert(value->getType().isObject());
5857+
assert(value->getFunction() == &getPullback());
58505858
blockTemporaries[value->getParentBlock()].push_back(value);
58515859
LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
58525860
auto insertion = blockTemporarySet.insert(value); (void)insertion;
@@ -5971,7 +5979,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
59715979
auto insertion = valueMap.try_emplace({origBB, originalValue},
59725980
adjointValue);
59735981
LLVM_DEBUG(getADDebugStream()
5974-
<< "The existing adjoint value will be replaced: "
5982+
<< "The new adjoint value, replacing the existing one, is: "
59755983
<< insertion.first->getSecond());
59765984
if (!insertion.second)
59775985
insertion.first->getSecond() = adjointValue;
@@ -6000,7 +6008,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
60006008
assert(originalValue->getType().isObject());
60016009
assert(newAdjointValue.getType().isObject());
60026010
assert(originalValue->getFunction() == &getOriginal());
6003-
LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
6011+
LLVM_DEBUG(getADDebugStream() << "Adding adjoint value for "
6012+
<< originalValue);
60046013
// The adjoint value must be in the tangent space.
60056014
assert(newAdjointValue.getType() ==
60066015
getRemappedTangentType(originalValue->getType()));
@@ -6494,8 +6503,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
64946503
SmallVector<SILValue, 4> directResults;
64956504
auto indirectResultIt = pullback.getIndirectResults().begin();
64966505
for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) {
6497-
auto resultType =
6498-
pullback.mapTypeIntoContext(resultInfo.getInterfaceType())->getCanonicalType();
6506+
auto resultType = pullback.mapTypeIntoContext(
6507+
resultInfo.getInterfaceType())->getCanonicalType();
64996508
if (resultInfo.isFormalDirect())
65006509
directResults.push_back(emitZeroDirect(resultType, pbLoc));
65016510
else
@@ -6539,7 +6548,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
65396548
auto &predBBActiveValues = activeValues[origPredBB];
65406549
for (auto activeValue : predBBActiveValues) {
65416550
LLVM_DEBUG(getADDebugStream()
6542-
<< "Propagating active adjoint " << activeValue
6551+
<< "Propagating adjoint of active value " << activeValue
65436552
<< " to predecessors' pullback blocks\n");
65446553
if (activeValue->getType().isObject()) {
65456554
auto activeValueAdj = getAdjointValue(origBB, activeValue);
@@ -7641,7 +7650,7 @@ SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs,
76417650
// TODO: Optimize for the case when lhs == rhs.
76427651
LLVM_DEBUG(getADDebugStream() <<
76437652
"Emitting adjoint accumulation for lhs: " << lhs <<
7644-
" and rhs: " << rhs << "\n");
7653+
" and rhs: " << rhs);
76457654
assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!");
76467655
assert(lhs->getType().isObject() && rhs->getType().isObject() &&
76477656
"Adjoint types must be both object types!");

0 commit comments

Comments
 (0)