@@ -5853,6 +5853,10 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
5853
5853
SILFunction *parentThunk, CanSILFunctionType linearMapType,
5854
5854
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
5855
5855
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) {
5856
+ LLVM_DEBUG (getADDebugStream () << " Getting a subset parameters thunk for " <<
5857
+ linearMapType << " from " << actualIndices << " to " <<
5858
+ desiredIndices << ' \n ' );
5859
+
5856
5860
SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap ();
5857
5861
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment ();
5858
5862
auto thunkType = buildThunkType (
@@ -5931,6 +5935,32 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
5931
5935
}
5932
5936
};
5933
5937
5938
+ // `actualIndices` and `desiredIndices` are with respect to the original
5939
+ // function. However, the differential parameters and pullback results may
5940
+ // already be w.r.t. a subset. We create a map between the original function's
5941
+ // actual parameter indices and the linear map's actual indices.
5942
+ // Example:
5943
+ // Original: (T0, T1, T2) -> R
5944
+ // Actual indices: 0, 2
5945
+ // Original differential: (T0, T2) -> R
5946
+ // Original pullback: R -> (T0, T2)
5947
+ // Desired indices w.r.t. original: 2
5948
+ // Desired indices w.r.t. linear map: 1
5949
+ SmallVector<unsigned , 4 > actualParamIndicesMap (
5950
+ actualIndices.parameters ->getCapacity (), UINT_MAX);
5951
+ {
5952
+ unsigned indexInBitVec = 0 ;
5953
+ for (auto index : actualIndices.parameters ->getIndices ()) {
5954
+ actualParamIndicesMap[index] = indexInBitVec;
5955
+ indexInBitVec++;
5956
+ }
5957
+ }
5958
+ auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
5959
+ auto mappedIndex = actualParamIndicesMap[index];
5960
+ assert (mappedIndex < actualIndices.parameters ->getCapacity ());
5961
+ return mappedIndex;
5962
+ };
5963
+
5934
5964
switch (kind) {
5935
5965
// Differential arguments are:
5936
5966
// - All indirect results, followed by:
@@ -5955,7 +5985,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
5955
5985
// Otherwise, construct and use a zero argument.
5956
5986
else {
5957
5987
auto zeroSILType =
5958
- linearMapType->getParameters ()[i].getSILStorageType ();
5988
+ linearMapType->getParameters ()[mapOriginalParameterIndex (i)]
5989
+ .getSILStorageType ();
5959
5990
buildZeroArgument (zeroSILType);
5960
5991
}
5961
5992
}
@@ -5974,7 +6005,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
5974
6005
};
5975
6006
// Iterate over actual indices.
5976
6007
for (unsigned i : actualIndices.parameters ->getIndices ()) {
5977
- auto resultInfo = linearMapType->getResults ()[i];
6008
+ auto resultInfo =
6009
+ linearMapType->getResults ()[mapOriginalParameterIndex (i)];
5978
6010
// Skip direct results. Only indirect results are relevant as arguments.
5979
6011
if (resultInfo.isFormalDirect ())
5980
6012
continue ;
@@ -6022,14 +6054,15 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
6022
6054
// - Do nothing if result is indirect.
6023
6055
// (It was already forwarded to the `apply` instruction).
6024
6056
// - Push it to `results` if result is direct.
6057
+ auto result = allResults[mapOriginalParameterIndex (i)];
6025
6058
if (desiredIndices.isWrtParameter (i)) {
6026
- if (allResults[i] ->getType ().isAddress ())
6059
+ if (result ->getType ().isAddress ())
6027
6060
continue ;
6028
- results.push_back (allResults[i] );
6061
+ results.push_back (result );
6029
6062
}
6030
6063
// Otherwise, cleanup the unused results.
6031
6064
else {
6032
- emitCleanup (builder, loc, allResults[i] );
6065
+ emitCleanup (builder, loc, result );
6033
6066
}
6034
6067
}
6035
6068
// Deallocate local allocations and return final direct result.
@@ -6047,6 +6080,11 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction(
6047
6080
SILValue origFnOperand, SILValue assocFn,
6048
6081
AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices,
6049
6082
SILAutoDiffIndices actualIndices) {
6083
+ LLVM_DEBUG (getADDebugStream () << " Getting a subset parameters thunk for "
6084
+ " associated function " << assocFn << " of the original function "
6085
+ << origFnOperand << " from " << actualIndices << " to " <<
6086
+ desiredIndices << ' \n ' );
6087
+
6050
6088
auto origFnType = origFnOperand->getType ().castTo <SILFunctionType>();
6051
6089
auto &module = getModule ();
6052
6090
auto lookupConformance = LookUpConformanceInModule (module .getSwiftModule ());
@@ -6368,6 +6406,11 @@ void ADContext::foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
6368
6406
}
6369
6407
6370
6408
bool ADContext::processAutoDiffFunctionInst (AutoDiffFunctionInst *adfi) {
6409
+ LLVM_DEBUG ({
6410
+ auto &s = getADDebugStream () << " Processing AutoDiffFunctionInst:\n " ;
6411
+ adfi->printInContext (s);
6412
+ });
6413
+
6371
6414
if (adfi->getNumAssociatedFunctions () ==
6372
6415
autodiff::getNumAutoDiffAssociatedFunctions (
6373
6416
adfi->getDifferentiationOrder ()))
0 commit comments