Skip to content

Commit 9382eaf

Browse files
committed
Minor edits.
1 parent 9873f41 commit 9382eaf

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1816,7 +1816,6 @@ emitAssociatedFunctionReference(
18161816
DifferentiationInvoker::Kind::IndirectDifferentiation)
18171817
contextualRequirements = std::get<2>(
18181818
invoker.getIndirectDifferentiation())->getRequirements();
1819-
// Parent function containing instruction
18201819
auto *newAttr = context.getOrCreateDifferentiableAttr(
18211820
originalFn, desiredIndices, contextualRequirements);
18221821
bool error = context.processDifferentiableAttribute(originalFn, newAttr, invoker);
@@ -2084,8 +2083,23 @@ buildThunkSignature(SILFunction *fn,
20842083

20852084
}
20862085

2086+
/// The thunk kinds used in the differentiation transform.
20872087
enum class DifferentiationThunkKind {
2088+
/// A reabstraction thunk.
2089+
///
2090+
/// Reabstraction thunks transform a function-typed value to another one with
2091+
/// different parameter/result abstraction patterns. This is identical to the
2092+
/// thunks generated by SILGen.
20882093
Reabstraction,
2094+
2095+
/// An index subset thunk.
2096+
///
2097+
/// An index subset thunk is used transform JVP/VJPs into a version that is
2098+
/// "wrt" fewer differentiation parameters.
2099+
/// - Differentials of thunked JVPs use zero for non-requested differentiation
2100+
// parameters.
2101+
/// - Pullbacks of thunked VJPs discard results for non-requested
2102+
/// differentiation parameters.
20892103
IndexSubset
20902104
};
20912105

@@ -5822,10 +5836,6 @@ ADContext::getOrCreateAssociatedFunctionIndexSubsetThunk(
58225836
assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap());
58235837
assocFnType = assocRef->getType().castTo<SILFunctionType>();
58245838

5825-
/*
5826-
auto arguments = map<SmallVector<SILValue, 4>>(
5827-
thunk->getArguments(), [](SILValue v) { return v; });
5828-
*/
58295839
SmallVector<SILValue, 4> arguments;
58305840
arguments.append(thunk->getArguments().begin(), thunk->getArguments().end());
58315841
auto *apply = builder.createApply(
@@ -5898,17 +5908,19 @@ SILValue ADContext::getCanonicalizedAutoDiffFunctionInst(
58985908
// Construct new curry think.
58995909
SILOptFunctionBuilder fb(transform);
59005910
auto *newThunk = fb.getOrCreateFunction(
5901-
loc, newThunkName, getSpecializedLinkage(thunk, thunk->getLinkage()),
5902-
thunkType, thunk->isBare(), thunk->isTransparent(),
5903-
thunk->isSerialized(), thunk->isDynamicallyReplaceable(),
5904-
ProfileCounter(), thunk->isThunk());
5911+
loc, newThunkName,
5912+
getSpecializedLinkage(thunk, thunk->getLinkage()), thunkType,
5913+
thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(),
5914+
thunk->isDynamicallyReplaceable(), ProfileCounter(),
5915+
thunk->isThunk());
59055916
if (newThunk->empty()) {
59065917
newThunk->setOwnershipEliminated();
59075918
SILFunctionCloner cloner(newThunk);
59085919
cloner.cloneFunction(thunk);
59095920
}
59105921

5911-
auto *retInst = cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
5922+
auto *retInst =
5923+
cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
59125924
AutoDiffFunctionInst *adfi;
59135925
{
59145926
SILBuilder builder(retInst);

0 commit comments

Comments
 (0)