Skip to content

Commit 4ce1aaa

Browse files
authored
[AutoDiff] Fix a subset-parameters thunk parameter indexing crasher. (#25699)
The subset-parameters thunk generation logic does not handle cases where the original JVP/VJP already produces a subset-parameters linear map. This patch fixes that by computing indices w.r.t. the linear map from indices w.r.t. the original function. Resolves [TF-594](https://bugs.swift.org/browse/TF-594) and unblocks tensorflow/swift-apis#147.
1 parent 0547eb5 commit 4ce1aaa

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5853,6 +5853,10 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
58535853
SILFunction *parentThunk, CanSILFunctionType linearMapType,
58545854
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
58555855
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) {
5856+
LLVM_DEBUG(getADDebugStream() << "Getting a subset parameters thunk for " <<
5857+
linearMapType << " from " << actualIndices << " to " <<
5858+
desiredIndices << '\n');
5859+
58565860
SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap();
58575861
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment();
58585862
auto thunkType = buildThunkType(
@@ -5931,6 +5935,32 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
59315935
}
59325936
};
59335937

5938+
// `actualIndices` and `desiredIndices` are with respect to the original
5939+
// function. However, the differential parameters and pullback results may
5940+
// already be w.r.t. a subset. We create a map between the original function's
5941+
// actual parameter indices and the linear map's actual indices.
5942+
// Example:
5943+
// Original: (T0, T1, T2) -> R
5944+
// Actual indices: 0, 2
5945+
// Original differential: (T0, T2) -> R
5946+
// Original pullback: R -> (T0, T2)
5947+
// Desired indices w.r.t. original: 2
5948+
// Desired indices w.r.t. linear map: 1
5949+
SmallVector<unsigned, 4> actualParamIndicesMap(
5950+
actualIndices.parameters->getCapacity(), UINT_MAX);
5951+
{
5952+
unsigned indexInBitVec = 0;
5953+
for (auto index : actualIndices.parameters->getIndices()) {
5954+
actualParamIndicesMap[index] = indexInBitVec;
5955+
indexInBitVec++;
5956+
}
5957+
}
5958+
auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
5959+
auto mappedIndex = actualParamIndicesMap[index];
5960+
assert(mappedIndex < actualIndices.parameters->getCapacity());
5961+
return mappedIndex;
5962+
};
5963+
59345964
switch (kind) {
59355965
// Differential arguments are:
59365966
// - All indirect results, followed by:
@@ -5955,7 +5985,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
59555985
// Otherwise, construct and use a zero argument.
59565986
else {
59575987
auto zeroSILType =
5958-
linearMapType->getParameters()[i].getSILStorageType();
5988+
linearMapType->getParameters()[mapOriginalParameterIndex(i)]
5989+
.getSILStorageType();
59595990
buildZeroArgument(zeroSILType);
59605991
}
59615992
}
@@ -5974,7 +6005,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
59746005
};
59756006
// Iterate over actual indices.
59766007
for (unsigned i : actualIndices.parameters->getIndices()) {
5977-
auto resultInfo = linearMapType->getResults()[i];
6008+
auto resultInfo =
6009+
linearMapType->getResults()[mapOriginalParameterIndex(i)];
59786010
// Skip direct results. Only indirect results are relevant as arguments.
59796011
if (resultInfo.isFormalDirect())
59806012
continue;
@@ -6022,14 +6054,15 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
60226054
// - Do nothing if result is indirect.
60236055
// (It was already forwarded to the `apply` instruction).
60246056
// - Push it to `results` if result is direct.
6057+
auto result = allResults[mapOriginalParameterIndex(i)];
60256058
if (desiredIndices.isWrtParameter(i)) {
6026-
if (allResults[i]->getType().isAddress())
6059+
if (result->getType().isAddress())
60276060
continue;
6028-
results.push_back(allResults[i]);
6061+
results.push_back(result);
60296062
}
60306063
// Otherwise, cleanup the unused results.
60316064
else {
6032-
emitCleanup(builder, loc, allResults[i]);
6065+
emitCleanup(builder, loc, result);
60336066
}
60346067
}
60356068
// Deallocate local allocations and return final direct result.
@@ -6047,6 +6080,11 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction(
60476080
SILValue origFnOperand, SILValue assocFn,
60486081
AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices,
60496082
SILAutoDiffIndices actualIndices) {
6083+
LLVM_DEBUG(getADDebugStream() << "Getting a subset parameters thunk for "
6084+
"associated function " << assocFn << " of the original function "
6085+
<< origFnOperand << " from " << actualIndices << " to " <<
6086+
desiredIndices << '\n');
6087+
60506088
auto origFnType = origFnOperand->getType().castTo<SILFunctionType>();
60516089
auto &module = getModule();
60526090
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
@@ -6368,6 +6406,11 @@ void ADContext::foldAutoDiffFunctionExtraction(AutoDiffFunctionInst *source) {
63686406
}
63696407

63706408
bool ADContext::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
6409+
LLVM_DEBUG({
6410+
auto &s = getADDebugStream() << "Processing AutoDiffFunctionInst:\n";
6411+
adfi->printInContext(s);
6412+
});
6413+
63716414
if (adfi->getNumAssociatedFunctions() ==
63726415
autodiff::getNumAutoDiffAssociatedFunctions(
63736416
adfi->getDifferentiationOrder()))

test/AutoDiff/superset_adjoint.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ SupersetVJPTests.test("CrossModuleClosure") {
3333
expectEqual(1, gradient(at: Float(1)) { x in x + 2 })
3434
}
3535

36+
// Regression test for TF-594: `foo` is declared differentiable with respect
// to `x` and `z` only, while the closure below differentiates with respect to
// `x` alone, i.e. a subset of an already-subset parameter list. `foo` returns
// a constant wrapped in `withoutDerivative(at:)`, so the expected gradient is 0.
SupersetVJPTests.test("SubsetOfSubset") {
37+
@differentiable(wrt: (x, z))
38+
func foo(_ x: Float, _ y: Float, _ z: Float) -> Float {
39+
withoutDerivative(at: 0)
40+
}
41+
expectEqual(0, gradient(at: 0, in: { x in foo(x, 0, 0) }))
42+
}
43+
3644
// FIXME: The expression `(+) as @differentiable (Float, @nondiff Float) -> Float`
3745
// forms a curry thunk of `Float.+` before conversion to @differentiable, and AD
3846
// doesn't know how to differentiate the curry thunk, so it produces a

0 commit comments

Comments
 (0)