@@ -472,6 +472,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
472
472
return mappedIndex;
473
473
};
474
474
475
+ auto toIndirectResultsIter = thunk->getIndirectResults ().begin ();
476
+ auto useNextIndirectResult = [&]() {
477
+ assert (toIndirectResultsIter != thunk->getIndirectResults ().end ());
478
+ arguments.push_back (*toIndirectResultsIter++);
479
+ };
480
+
475
481
switch (kind) {
476
482
// Differential arguments are:
477
483
// - All indirect results, followed by:
@@ -480,9 +486,29 @@ getOrCreateSubsetParametersThunkForLinearMap(
480
486
// indices).
481
487
// - Zeros (when parameter is not in desired indices).
482
488
case AutoDiffDerivativeFunctionKind::JVP: {
483
- // Forward all indirect results.
484
- arguments.append (thunk->getIndirectResults ().begin (),
485
- thunk->getIndirectResults ().end ());
489
+ unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults ();
490
+ // Forward desired indirect results.
491
+ for (unsigned idx : *actualConfig.resultIndices ) {
492
+ if (idx >= numIndirectResults)
493
+ break ;
494
+
495
+ auto resultInfo = linearMapType->getResults ()[idx];
496
+ assert (idx < linearMapType->getNumResults ());
497
+
498
+ // Forward result argument in case we do not need to thunk it away.
499
+ if (desiredConfig.resultIndices ->contains (idx)) {
500
+ useNextIndirectResult ();
501
+ continue ;
502
+ }
503
+
504
+ // Otherwise, allocate and use an uninitialized indirect result.
505
+ auto *indirectResult = builder.createAllocStack (
506
+ loc, resultInfo.getSILStorageInterfaceType ());
507
+ localAllocations.push_back (indirectResult);
508
+ arguments.push_back (indirectResult);
509
+ }
510
+ assert (toIndirectResultsIter == thunk->getIndirectResults ().end ());
511
+
486
512
auto toArgIter = thunk->getArgumentsWithoutIndirectResults ().begin ();
487
513
auto useNextArgument = [&]() { arguments.push_back (*toArgIter++); };
488
514
// Iterate over actual indices.
@@ -507,10 +533,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
507
533
// - Zeros (when parameter is not in desired indices).
508
534
// - All actual arguments.
509
535
case AutoDiffDerivativeFunctionKind::VJP: {
510
- auto toIndirectResultsIter = thunk->getIndirectResults ().begin ();
511
- auto useNextIndirectResult = [&]() {
512
- arguments.push_back (*toIndirectResultsIter++);
513
- };
514
536
// Collect pullback arguments.
515
537
unsigned pullbackResultIndex = 0 ;
516
538
for (unsigned i : actualConfig.parameterIndices ->getIndices ()) {
@@ -539,8 +561,18 @@ getOrCreateSubsetParametersThunkForLinearMap(
539
561
arguments.push_back (indirectResult);
540
562
}
541
563
// Forward all actual non-indirect-result arguments.
542
- arguments.append (thunk->getArgumentsWithoutIndirectResults ().begin (),
543
- thunk->getArgumentsWithoutIndirectResults ().end () - 1 );
564
+ auto thunkArgs = thunk->getArgumentsWithoutIndirectResults ();
565
+ // Slice out the function to be called.
566
+ thunkArgs = thunkArgs.slice (0 , thunkArgs.size () - 1 );
567
+ unsigned thunkArg = 0 ;
568
+ for (unsigned idx : *actualConfig.resultIndices ) {
569
+ // Forward result argument in case we do not need to thunk it away.
570
+ if (desiredConfig.resultIndices ->contains (idx))
571
+ arguments.push_back (thunkArgs[thunkArg++]);
572
+ else { // Otherwise, zero it out.
573
+ buildZeroArgument (linearMapType->getParameters ()[arguments.size ()]);
574
+ }
575
+ }
544
576
break ;
545
577
}
546
578
}
@@ -550,10 +582,33 @@ getOrCreateSubsetParametersThunkForLinearMap(
550
582
auto *ai = builder.createApply (loc, linearMap, SubstitutionMap (), arguments);
551
583
552
584
// If differential thunk, deallocate local allocations and directly return
553
- // `apply` result.
585
+ // `apply` result (if it is desired) .
554
586
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
587
+ SmallVector<SILValue, 8 > differentialDirectResults;
588
+ extractAllElements (ai, builder, differentialDirectResults);
589
+ SmallVector<SILValue, 8 > allResults;
590
+ collectAllActualResultsInTypeOrder (ai, differentialDirectResults,
591
+ allResults);
592
+ unsigned numResults = thunk->getConventions ().getNumDirectSILResults () +
593
+ thunk->getConventions ().getNumDirectSILResults ();
594
+ SmallVector<SILValue, 8 > results;
595
+ for (unsigned idx : *actualConfig.resultIndices ) {
596
+ if (idx >= numResults)
597
+ break ;
598
+
599
+ auto result = allResults[idx];
600
+ if (desiredConfig.isWrtResult (idx))
601
+ results.push_back (result);
602
+ else {
603
+ if (result->getType ().isAddress ())
604
+ builder.emitDestroyAddrAndFold (loc, result);
605
+ else
606
+ builder.emitDestroyValueOperation (loc, result);
607
+ }
608
+ }
555
609
cleanupValues ();
556
- builder.createReturn (loc, ai);
610
+ auto result = joinElements (results, builder, loc);
611
+ builder.createReturn (loc, result);
557
612
return {thunk, interfaceSubs};
558
613
}
559
614
0 commit comments