Skip to content

Commit fe2df1e

Browse files
committed
Properly slice indirect results as well
1 parent f372c9f commit fe2df1e

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ void collectMinimalIndicesForFunctionCall(
232232
auto &param = paramAndIdx.value();
233233
if (!param.isIndirectMutating())
234234
continue;
235-
unsigned idx = paramAndIdx.index();
235+
unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults();
236236
auto inoutArg = ai->getArgument(idx);
237237
results.push_back(inoutArg);
238238
resultIndices.push_back(inoutParamResultIndex++);

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
474474
return mappedIndex;
475475
};
476476

477+
auto toIndirectResultsIter = thunk->getIndirectResults().begin();
478+
auto useNextIndirectResult = [&]() {
479+
assert(toIndirectResultsIter != thunk->getIndirectResults().end());
480+
arguments.push_back(*toIndirectResultsIter++);
481+
};
482+
477483
switch (kind) {
478484
// Differential arguments are:
479485
// - All indirect results, followed by:
@@ -482,9 +488,29 @@ getOrCreateSubsetParametersThunkForLinearMap(
482488
// indices).
483489
// - Zeros (when parameter is not in desired indices).
484490
case AutoDiffDerivativeFunctionKind::JVP: {
485-
// Forward all indirect results.
486-
arguments.append(thunk->getIndirectResults().begin(),
487-
thunk->getIndirectResults().end());
491+
unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults();
492+
// Forward desired indirect results
493+
for (unsigned idx : *actualConfig.resultIndices) {
494+
if (idx >= numIndirectResults)
495+
break;
496+
497+
auto resultInfo = linearMapType->getResults()[idx];
498+
assert(idx < linearMapType->getNumResults());
499+
500+
// Forward result argument in case we do not need to thunk it away
501+
if (desiredConfig.resultIndices->contains(idx)) {
502+
useNextIndirectResult();
503+
continue;
504+
}
505+
506+
// Otherwise, allocate and use an uninitialized indirect result
507+
auto *indirectResult = builder.createAllocStack(
508+
loc, resultInfo.getSILStorageInterfaceType());
509+
localAllocations.push_back(indirectResult);
510+
arguments.push_back(indirectResult);
511+
}
512+
assert(toIndirectResultsIter == thunk->getIndirectResults().end());
513+
488514
auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
489515
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
490516
// Iterate over actual indices.
@@ -509,10 +535,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
509535
// - Zeros (when parameter is not in desired indices).
510536
// - All actual arguments.
511537
case AutoDiffDerivativeFunctionKind::VJP: {
512-
auto toIndirectResultsIter = thunk->getIndirectResults().begin();
513-
auto useNextIndirectResult = [&]() {
514-
arguments.push_back(*toIndirectResultsIter++);
515-
};
516538
// Collect pullback arguments.
517539
unsigned pullbackResultIndex = 0;
518540
for (unsigned i : actualConfig.parameterIndices->getIndices()) {

0 commit comments

Comments
 (0)