@@ -1816,7 +1816,6 @@ emitAssociatedFunctionReference(
1816
1816
DifferentiationInvoker::Kind::IndirectDifferentiation)
1817
1817
contextualRequirements = std::get<2 >(
1818
1818
invoker.getIndirectDifferentiation ())->getRequirements ();
1819
- // Parent function containing instruction
1820
1819
auto *newAttr = context.getOrCreateDifferentiableAttr (
1821
1820
originalFn, desiredIndices, contextualRequirements);
1822
1821
bool error = context.processDifferentiableAttribute (originalFn, newAttr, invoker);
@@ -2084,8 +2083,23 @@ buildThunkSignature(SILFunction *fn,
2084
2083
2085
2084
}
2086
2085
2086
+ // / The thunk kinds used in the differentiation transform.
2087
2087
enum class DifferentiationThunkKind {
2088
+ // / A reabstraction thunk.
2089
+ // /
2090
+ // / Reabstraction thunks transform a function-typed value to another one with
2091
+ // / different parameter/result abstraction patterns. This is identical to the
2092
+ // / thunks generated by SILGen.
2088
2093
Reabstraction,
2094
+
2095
+ // / An index subset thunk.
2096
+ // /
2097
+ // / An index subset thunk is used transform JVP/VJPs into a version that is
2098
+ // / "wrt" fewer differentiation parameters.
2099
+ // / - Differentials of thunked JVPs use zero for non-requested differentiation
2100
+ // parameters.
2101
+ // / - Pullbacks of thunked VJPs discard results for non-requested
2102
+ // / differentiation parameters.
2089
2103
IndexSubset
2090
2104
};
2091
2105
@@ -5822,10 +5836,6 @@ ADContext::getOrCreateAssociatedFunctionIndexSubsetThunk(
5822
5836
assocSubstMap = assocSubstMap.subst (thunk->getForwardingSubstitutionMap ());
5823
5837
assocFnType = assocRef->getType ().castTo <SILFunctionType>();
5824
5838
5825
- /*
5826
- auto arguments = map<SmallVector<SILValue, 4>>(
5827
- thunk->getArguments(), [](SILValue v) { return v; });
5828
- */
5829
5839
SmallVector<SILValue, 4 > arguments;
5830
5840
arguments.append (thunk->getArguments ().begin (), thunk->getArguments ().end ());
5831
5841
auto *apply = builder.createApply (
@@ -5898,17 +5908,19 @@ SILValue ADContext::getCanonicalizedAutoDiffFunctionInst(
5898
5908
// Construct new curry think.
5899
5909
SILOptFunctionBuilder fb (transform);
5900
5910
auto *newThunk = fb.getOrCreateFunction (
5901
- loc, newThunkName, getSpecializedLinkage (thunk, thunk->getLinkage ()),
5902
- thunkType, thunk->isBare (), thunk->isTransparent (),
5903
- thunk->isSerialized (), thunk->isDynamicallyReplaceable (),
5904
- ProfileCounter (), thunk->isThunk ());
5911
+ loc, newThunkName,
5912
+ getSpecializedLinkage (thunk, thunk->getLinkage ()), thunkType,
5913
+ thunk->isBare (), thunk->isTransparent (), thunk->isSerialized (),
5914
+ thunk->isDynamicallyReplaceable (), ProfileCounter (),
5915
+ thunk->isThunk ());
5905
5916
if (newThunk->empty ()) {
5906
5917
newThunk->setOwnershipEliminated ();
5907
5918
SILFunctionCloner cloner (newThunk);
5908
5919
cloner.cloneFunction (thunk);
5909
5920
}
5910
5921
5911
- auto *retInst = cast<ReturnInst>(newThunk->findReturnBB ()->getTerminator ());
5922
+ auto *retInst =
5923
+ cast<ReturnInst>(newThunk->findReturnBB ()->getTerminator ());
5912
5924
AutoDiffFunctionInst *adfi;
5913
5925
{
5914
5926
SILBuilder builder (retInst);
0 commit comments