@@ -541,8 +541,17 @@ getOrCreateSubsetParametersThunkForLinearMap(
541
541
arguments.push_back (indirectResult);
542
542
}
543
543
// Forward all actual non-indirect-result arguments.
544
- arguments.append (thunk->getArgumentsWithoutIndirectResults ().begin (),
545
- thunk->getArgumentsWithoutIndirectResults ().end () - 1 );
544
+ auto thunkArgs = thunk->getArgumentsWithoutIndirectResults ();
545
+ // Slice out the function to be called
546
+ thunkArgs = thunkArgs.slice (0 , thunkArgs.size () - 1 );
547
+ unsigned thunkArg = 0 ;
548
+ for (unsigned idx : *actualConfig.resultIndices ) {
549
+ // Forward result argument in case we do not need to thunk it away
550
+ if (desiredConfig.resultIndices ->contains (idx))
551
+ arguments.push_back (thunkArgs[thunkArg++]);
552
+ else // otherwise, zero it out
553
+ buildZeroArgument (linearMapType->getParameters ()[arguments.size ()]);
554
+ }
546
555
break ;
547
556
}
548
557
}
@@ -552,10 +561,33 @@ getOrCreateSubsetParametersThunkForLinearMap(
552
561
auto *ai = builder.createApply (loc, linearMap, SubstitutionMap (), arguments);
553
562
554
563
// If differential thunk, deallocate local allocations and directly return
555
- // `apply` result.
564
+ // `apply` result (if it is desired) .
556
565
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
566
+ SmallVector<SILValue, 8 > differentialDirectResults;
567
+ extractAllElements (ai, builder, differentialDirectResults);
568
+ SmallVector<SILValue, 8 > allResults;
569
+ collectAllActualResultsInTypeOrder (ai, differentialDirectResults, allResults);
570
+ unsigned numResults = thunk->getConventions ().getNumDirectSILResults () +
571
+ thunk->getConventions ().getNumDirectSILResults ();
572
+ SmallVector<SILValue, 8 > results;
573
+ for (unsigned idx : *actualConfig.resultIndices ) {
574
+ if (idx >= numResults)
575
+ break ;
576
+
577
+ auto result = allResults[idx];
578
+ if (desiredConfig.isWrtResult (idx))
579
+ results.push_back (result);
580
+ else {
581
+ if (result->getType ().isAddress ())
582
+ builder.emitDestroyAddrAndFold (loc, result);
583
+ else
584
+ builder.emitDestroyValueOperation (loc, result);
585
+ }
586
+ }
587
+
557
588
cleanupValues ();
558
- builder.createReturn (loc, ai);
589
+ auto result = joinElements (results, builder, loc);
590
+ builder.createReturn (loc, result);
559
591
return {thunk, interfaceSubs};
560
592
}
561
593
0 commit comments