Skip to content

Commit f372c9f

Browse files
committed
Enable slicing of results in reabstraction thunks
1 parent 0872802 commit f372c9f

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,17 @@ getOrCreateSubsetParametersThunkForLinearMap(
541541
arguments.push_back(indirectResult);
542542
}
543543
// Forward all actual non-indirect-result arguments.
544-
arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(),
545-
thunk->getArgumentsWithoutIndirectResults().end() - 1);
544+
auto thunkArgs = thunk->getArgumentsWithoutIndirectResults();
545+
// Slice out the function to be called
546+
thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1);
547+
unsigned thunkArg = 0;
548+
for (unsigned idx : *actualConfig.resultIndices) {
549+
// Forward result argument in case we do not need to thunk it away
550+
if (desiredConfig.resultIndices->contains(idx))
551+
arguments.push_back(thunkArgs[thunkArg++]);
552+
else // otherwise, zero it out
553+
buildZeroArgument(linearMapType->getParameters()[arguments.size()]);
554+
}
546555
break;
547556
}
548557
}
@@ -552,10 +561,33 @@ getOrCreateSubsetParametersThunkForLinearMap(
552561
auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments);
553562

554563
// If differential thunk, deallocate local allocations and directly return
555-
// `apply` result.
564+
// `apply` result (if it is desired).
556565
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
566+
SmallVector<SILValue, 8> differentialDirectResults;
567+
extractAllElements(ai, builder, differentialDirectResults);
568+
SmallVector<SILValue, 8> allResults;
569+
collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults);
570+
unsigned numResults = thunk->getConventions().getNumDirectSILResults() +
571+
thunk->getConventions().getNumDirectSILResults();
572+
SmallVector<SILValue, 8> results;
573+
for (unsigned idx : *actualConfig.resultIndices) {
574+
if (idx >= numResults)
575+
break;
576+
577+
auto result = allResults[idx];
578+
if (desiredConfig.isWrtResult(idx))
579+
results.push_back(result);
580+
else {
581+
if (result->getType().isAddress())
582+
builder.emitDestroyAddrAndFold(loc, result);
583+
else
584+
builder.emitDestroyValueOperation(loc, result);
585+
}
586+
}
587+
557588
cleanupValues();
558-
builder.createReturn(loc, ai);
589+
auto result = joinElements(results, builder, loc);
590+
builder.createReturn(loc, result);
559591
return {thunk, interfaceSubs};
560592
}
561593

0 commit comments

Comments
 (0)