Skip to content

Commit d89e9d1

Browse files
authored
[AutoDiff] NFC: gardening. (#28251)
Improve documentation comments and differentiation debug output. Add correctness assertions. Wrap lines to 80 columns. Change the `DenseMap` key type for `LinearMapInfo::linearMapFieldMap` from `SILInstruction *` to `ApplyInst *`, since stored properties (`struct_extract` instructions) cannot have custom derivatives.
1 parent 54ba428 commit d89e9d1

File tree

1 file changed

+63
-54
lines changed

1 file changed

+63
-54
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
9898
return nullptr;
9999
}
100100

101-
/// Given a value, find its single `destructure_tuple` user if the value is
101+
/// Given a value, finds its single `destructure_tuple` user if the value is
102102
/// tuple-typed and such a user exists.
103103
static DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
104104
bool foundDestructureTupleUser = false;
@@ -130,7 +130,7 @@ static void forEachApplyDirectResult(
130130
resultCallback(result);
131131
}
132132

133-
/// Given a function, gather all of its formal results (both direct and
133+
/// Given a function, gathers all of its formal results (both direct and
134134
/// indirect) in an order defined by its result type. Note that "formal results"
135135
/// refer to result values in the body of the function, not at call sites.
136136
static void
@@ -154,7 +154,7 @@ collectAllFormalResultsInTypeOrder(SILFunction &function,
154154
: indResults[indResIdx++]);
155155
}
156156

157-
/// Given a function, gather all of its direct results in an order defined by
157+
/// Given a function, gathers all of its direct results in an order defined by
158158
/// its result type. Note that "formal results" refer to result values in the
159159
/// body of the function, not at call sites.
160160
static void
@@ -171,7 +171,7 @@ collectAllDirectResultsInTypeOrder(SILFunction &function,
171171
results.push_back(retVal);
172172
}
173173

174-
/// Given a function call site, gather all of its actual results (both direct
174+
/// Given a function call site, gathers all of its actual results (both direct
175175
/// and indirect) in an order defined by its result type.
176176
static void collectAllActualResultsInTypeOrder(
177177
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
@@ -291,7 +291,7 @@ static GenericParamList *cloneGenericParameters(ASTContext &ctx,
291291
return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc());
292292
}
293293

294-
/// Given an `differentiable_function` instruction, find the corresponding
294+
/// Given a `differentiable_function` instruction, finds the corresponding
295295
/// differential operator used in the AST. If no differential operator is found,
296296
/// return nullptr.
297297
static DifferentiableFunctionExpr *
@@ -412,15 +412,23 @@ struct DifferentiationInvoker {
412412

413413
class DifferentiableActivityInfo;
414414

415-
/// Information about the JVP/VJP function produced during JVP/VJP generation,
416-
/// e.g. mappings from original values to corresponding values in the
417-
/// pullback/differential struct.
415+
/// Linear map struct and branching trace enum information for an original
416+
/// function and a derivative function (JVP or VJP).
418417
///
419-
/// A linear map struct is an aggregate value containing linear maps checkpointed
420-
/// during the JVP/VJP computation. Linear map structs are generated for every
421-
/// original function during JVP/VJP generation. Linear map struct values are
422-
/// constructed by JVP/VJP functions and consumed by pullback/differential
423-
/// functions.
418+
/// Linear map structs contain all callee linear maps produced in a JVP/VJP
419+
/// basic block. A linear map struct is created for each basic block in the
420+
/// original function, and a linear map struct field is created for every active
421+
/// `apply` in the original basic block.
422+
///
423+
/// Branching trace enums model the control flow graph of the original function.
424+
/// A branching trace enum is created for each basic block in the original
425+
/// function, and a branching trace enum case is created for every basic block
426+
/// predecessor/successor. This supports control flow differentiation: JVP/VJP
427+
/// functions build branching trace enums to record an execution trace. Indirect
428+
/// branching trace enums are created for basic blocks that are in loops.
429+
///
430+
/// Linear map struct values and branching trace enum values are constructed in
431+
/// JVP/VJP functions and consumed in pullback/differential functions.
424432
class LinearMapInfo {
425433
private:
426434
/// The linear map kind.
@@ -446,13 +454,12 @@ class LinearMapInfo {
446454
/// For differentials: these are successor enums.
447455
DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
448456

449-
/// Mapping from `apply` and `struct_extract` instructions in the original
450-
/// function to the corresponding linear map declaration in the linear map
451-
/// struct.
452-
DenseMap<SILInstruction *, VarDecl *> linearMapValueMap;
457+
/// Mapping from `apply` instructions in the original function to the
458+
/// corresponding linear map field declaration in the linear map struct.
459+
DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;
453460

454-
/// Mapping from predecessor+succcessor basic block pairs in original function
455-
/// to the corresponding branching trace enum case.
461+
/// Mapping from predecessor-successor basic block pairs in the original
462+
/// function to the corresponding branching trace enum case.
456463
DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
457464
branchingTraceEnumCases;
458465

@@ -505,7 +512,7 @@ class LinearMapInfo {
505512
llvm_unreachable("No files?");
506513
}
507514

508-
/// Compute and set the access level for the given nominal type, given the
515+
/// Computes and sets the access level for the given nominal type, given the
509516
/// original function linkage.
510517
void computeAccessLevel(
511518
NominalTypeDecl *nominal, SILLinkage originalLinkage) {
@@ -661,8 +668,8 @@ class LinearMapInfo {
661668
return linearMapStruct;
662669
}
663670

664-
/// Add a linear map to the linear map struct.
665-
VarDecl *addLinearMapDecl(SILInstruction *inst, SILType linearMapType) {
671+
/// Adds a linear map field to the linear map struct.
672+
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
666673
// IRGen requires decls to have AST types (not `SILFunctionType`), so we
667674
// convert the `SILFunctionType` of the linear map to a `FunctionType` with
668675
// the same parameters and results.
@@ -678,28 +685,28 @@ class LinearMapInfo {
678685
astFnTy = FunctionType::get(
679686
params, silFnTy->getAllResultsInterfaceType().getASTType());
680687

681-
auto *origBB = inst->getParent();
688+
auto *origBB = ai->getParent();
682689
auto *linMapStruct = getLinearMapStruct(origBB);
683690
std::string linearMapName;
684691
switch (kind) {
685692
case AutoDiffLinearMapKind::Differential:
686-
linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size());
693+
linearMapName = "differential_" + llvm::itostr(linearMapFieldMap.size());
687694
break;
688695
case AutoDiffLinearMapKind::Pullback:
689-
linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size());
696+
linearMapName = "pullback_" + llvm::itostr(linearMapFieldMap.size());
690697
break;
691698
}
692699
auto *linearMapDecl = addVarDecl(linMapStruct, linearMapName, astFnTy);
693-
linearMapValueMap.insert({inst, linearMapDecl});
700+
linearMapFieldMap.insert({ai, linearMapDecl});
694701
return linearMapDecl;
695702
}
696703

697-
/// Given an `apply` instruction, conditionally adds its linear map function
698-
/// to the linear map struct if it is active.
704+
/// Given an `apply` instruction, conditionally adds a linear map struct field
705+
/// for its linear map function if it is active.
699706
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
700707
const SILAutoDiffIndices &indices);
701708

702-
/// Generate linear map struct and branching enum declarations for the given
709+
/// Generates linear map struct and branching enum declarations for the given
703710
/// function. Linear map structs are populated with linear map fields and a
704711
/// branching enum field.
705712
void generateDifferentiationDataStructures(
@@ -771,12 +778,13 @@ class LinearMapInfo {
771778
return linearMapStructEnumFields.lookup(linearMapStruct);
772779
}
773780

774-
/// Finds the linear map declaration in the pullback struct for an `apply` or
775-
/// `struct_extract` in the original function.
776-
VarDecl *lookUpLinearMapDecl(SILInstruction *inst) {
777-
auto lookup = linearMapValueMap.find(inst);
778-
assert(lookup != linearMapValueMap.end() &&
779-
"No linear map declaration corresponding to the given instruction");
781+
/// Finds the linear map field in the linear map struct for the given
782+
/// `apply` instruction in the original function.
783+
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
784+
assert(ai->getFunction() == original);
785+
auto lookup = linearMapFieldMap.find(ai);
786+
assert(lookup != linearMapFieldMap.end() &&
787+
"No linear map field corresponding to the given `apply`");
780788
return lookup->getSecond();
781789
}
782790
};
@@ -1047,7 +1055,8 @@ class ADContext {
10471055
for (auto *da : original->getAttrs().getAttributes<DifferentiableAttr>()) {
10481056
auto *daParamIndices = da->getParameterIndices();
10491057
auto *daIndexSet = autodiff::getLoweredParameterIndices(
1050-
daParamIndices, original->getInterfaceType()->castTo<AnyFunctionType>());
1058+
daParamIndices,
1059+
original->getInterfaceType()->castTo<AnyFunctionType>());
10511060
// If all indices in `indexSet` are in `daIndexSet`, and it has fewer
10521061
// indices than our current candidate and a primitive VJP, then `da` is
10531062
// our new candidate.
@@ -1646,8 +1655,6 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
16461655
return false;
16471656
}
16481657

1649-
/// Takes an `apply` instruction and adds its linear map function to the
1650-
/// linear map struct if it is active.
16511658
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16521659
const SILAutoDiffIndices &indices) {
16531660
SmallVector<SILValue, 4> allResults;
@@ -1696,15 +1703,15 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
16961703
[&](CanSILFunctionType origFnTy) {
16971704
// Check non-differentiable arguments.
16981705
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
1699-
auto remappedParamType =
1700-
origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType();
1706+
auto remappedParamType = origFnTy->getParameters()[paramIndex]
1707+
.getSILStorageInterfaceType();
17011708
if (applyIndices.isWrtParameter(paramIndex) &&
17021709
!remappedParamType.isDifferentiable(derivative->getModule()))
17031710
return true;
17041711
}
17051712
// Check non-differentiable results.
1706-
auto remappedResultType =
1707-
origFnTy->getResults()[applyIndices.source].getSILStorageInterfaceType();
1713+
auto remappedResultType = origFnTy->getResults()[applyIndices.source]
1714+
.getSILStorageInterfaceType();
17081715
if (!remappedResultType.isDifferentiable(derivative->getModule()))
17091716
return true;
17101717
return false;
@@ -1713,13 +1720,13 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
17131720
return;
17141721

17151722
AutoDiffDerivativeFunctionKind derivativeFnKind(kind);
1716-
auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
1717-
parameters, source, derivativeFnKind, context.getTypeConverter(),
1718-
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
1723+
auto derivativeFnType =
1724+
remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
1725+
parameters, source, derivativeFnKind, context.getTypeConverter(),
1726+
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
17191727

17201728
auto derivativeFnResultTypes =
17211729
derivativeFnType->getAllResultsInterfaceType().castTo<TupleType>();
1722-
derivativeFnResultTypes->getElement(derivativeFnResultTypes->getElements().size() - 1);
17231730
auto linearMapSILType = SILType::getPrimitiveObjectType(
17241731
derivativeFnResultTypes
17251732
->getElement(derivativeFnResultTypes->getElements().size() - 1)
@@ -1760,8 +1767,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
17601767
break;
17611768
}
17621769
for (auto &origBB : *original) {
1763-
auto *traceEnum =
1764-
createBranchingTraceDecl(&origBB, indices, derivativeFnGenSig, loopInfo);
1770+
auto *traceEnum = createBranchingTraceDecl(
1771+
&origBB, indices, derivativeFnGenSig, loopInfo);
17651772
branchingTraceDecls.insert({&origBB, traceEnum});
17661773
if (origBB.isEntry())
17671774
continue;
@@ -5847,6 +5854,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
58475854
/// Record a temporary value for cleanup before its block's terminator.
58485855
SILValue recordTemporary(SILValue value) {
58495856
assert(value->getType().isObject());
5857+
assert(value->getFunction() == &getPullback());
58505858
blockTemporaries[value->getParentBlock()].push_back(value);
58515859
LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
58525860
auto insertion = blockTemporarySet.insert(value); (void)insertion;
@@ -5971,7 +5979,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
59715979
auto insertion = valueMap.try_emplace({origBB, originalValue},
59725980
adjointValue);
59735981
LLVM_DEBUG(getADDebugStream()
5974-
<< "The existing adjoint value will be replaced: "
5982+
<< "The new adjoint value, replacing the existing one, is: "
59755983
<< insertion.first->getSecond());
59765984
if (!insertion.second)
59775985
insertion.first->getSecond() = adjointValue;
@@ -6000,7 +6008,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
60006008
assert(originalValue->getType().isObject());
60016009
assert(newAdjointValue.getType().isObject());
60026010
assert(originalValue->getFunction() == &getOriginal());
6003-
LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
6011+
LLVM_DEBUG(getADDebugStream() << "Adding adjoint value for "
6012+
<< originalValue);
60046013
// The adjoint value must be in the tangent space.
60056014
assert(newAdjointValue.getType() ==
60066015
getRemappedTangentType(originalValue->getType()));
@@ -6494,8 +6503,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
64946503
SmallVector<SILValue, 4> directResults;
64956504
auto indirectResultIt = pullback.getIndirectResults().begin();
64966505
for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) {
6497-
auto resultType =
6498-
pullback.mapTypeIntoContext(resultInfo.getInterfaceType())->getCanonicalType();
6506+
auto resultType = pullback.mapTypeIntoContext(
6507+
resultInfo.getInterfaceType())->getCanonicalType();
64996508
if (resultInfo.isFormalDirect())
65006509
directResults.push_back(emitZeroDirect(resultType, pbLoc));
65016510
else
@@ -6539,7 +6548,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
65396548
auto &predBBActiveValues = activeValues[origPredBB];
65406549
for (auto activeValue : predBBActiveValues) {
65416550
LLVM_DEBUG(getADDebugStream()
6542-
<< "Propagating active adjoint " << activeValue
6551+
<< "Propagating adjoint of active value " << activeValue
65436552
<< " to predecessors' pullback blocks\n");
65446553
if (activeValue->getType().isObject()) {
65456554
auto activeValueAdj = getAdjointValue(origBB, activeValue);
@@ -7641,7 +7650,7 @@ SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs,
76417650
// TODO: Optimize for the case when lhs == rhs.
76427651
LLVM_DEBUG(getADDebugStream() <<
76437652
"Emitting adjoint accumulation for lhs: " << lhs <<
7644-
" and rhs: " << rhs << "\n");
7653+
" and rhs: " << rhs);
76457654
assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!");
76467655
assert(lhs->getType().isObject() && rhs->getType().isObject() &&
76477656
"Adjoint types must be both object types!");

0 commit comments

Comments
 (0)