@@ -474,6 +474,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
474
474
return mappedIndex;
475
475
};
476
476
477
+ auto toIndirectResultsIter = thunk->getIndirectResults ().begin ();
478
+ auto useNextIndirectResult = [&]() {
479
+ assert (toIndirectResultsIter != thunk->getIndirectResults ().end ());
480
+ arguments.push_back (*toIndirectResultsIter++);
481
+ };
482
+
477
483
switch (kind) {
478
484
// Differential arguments are:
479
485
// - All indirect results, followed by:
@@ -482,9 +488,29 @@ getOrCreateSubsetParametersThunkForLinearMap(
482
488
// indices).
483
489
// - Zeros (when parameter is not in desired indices).
484
490
case AutoDiffDerivativeFunctionKind::JVP: {
485
- // Forward all indirect results.
486
- arguments.append (thunk->getIndirectResults ().begin (),
487
- thunk->getIndirectResults ().end ());
491
+ unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults ();
492
+ // Forward desired indirect results
493
+ for (unsigned idx : *actualConfig.resultIndices ) {
494
+ if (idx >= numIndirectResults)
495
+ break ;
496
+
497
+ auto resultInfo = linearMapType->getResults ()[idx];
498
+ assert (idx < linearMapType->getNumResults ());
499
+
500
+ // Forward result argument in case we do not need to thunk it away
501
+ if (desiredConfig.resultIndices ->contains (idx)) {
502
+ useNextIndirectResult ();
503
+ continue ;
504
+ }
505
+
506
+ // Otherwise, allocate and use an uninitialized indirect result
507
+ auto *indirectResult = builder.createAllocStack (
508
+ loc, resultInfo.getSILStorageInterfaceType ());
509
+ localAllocations.push_back (indirectResult);
510
+ arguments.push_back (indirectResult);
511
+ }
512
+ assert (toIndirectResultsIter == thunk->getIndirectResults ().end ());
513
+
488
514
auto toArgIter = thunk->getArgumentsWithoutIndirectResults ().begin ();
489
515
auto useNextArgument = [&]() { arguments.push_back (*toArgIter++); };
490
516
// Iterate over actual indices.
@@ -509,10 +535,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
509
535
// - Zeros (when parameter is not in desired indices).
510
536
// - All actual arguments.
511
537
case AutoDiffDerivativeFunctionKind::VJP: {
512
- auto toIndirectResultsIter = thunk->getIndirectResults ().begin ();
513
- auto useNextIndirectResult = [&]() {
514
- arguments.push_back (*toIndirectResultsIter++);
515
- };
516
538
// Collect pullback arguments.
517
539
unsigned pullbackResultIndex = 0 ;
518
540
for (unsigned i : actualConfig.parameterIndices ->getIndices ()) {
0 commit comments