Skip to content

Commit 4997035

Browse files
committed
Converting a last few areas to use multiple result indices.
1 parent ff8dc58 commit 4997035

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,11 +1256,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
12561256
diffAttr->getDerivativeGenericSignature()) &&
12571257
"Type-checking should resolve derivative generic signatures for "
12581258
"all original SIL functions with generic signatures");
1259-
auto numResults =
1260-
F->getLoweredFunctionType()->getNumResults() +
1261-
F->getLoweredFunctionType()->getNumIndirectMutatingParameters();
1262-
auto *resultIndices = IndexSubset::getDefault(
1263-
getASTContext(), numResults, /*includeAll*/ true);
1259+
auto *resultIndices =
1260+
autodiff::getAllFunctionSemanticResultIndices(AFD);
12641261
auto witnessGenSig =
12651262
autodiff::getDifferentiabilityWitnessGenericSignature(
12661263
AFD->getGenericSignature(),
@@ -1289,11 +1286,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
12891286
auto witnessGenSig =
12901287
autodiff::getDifferentiabilityWitnessGenericSignature(
12911288
origAFD->getGenericSignature(), AFD->getGenericSignature());
1292-
auto numResults =
1293-
origFn->getLoweredFunctionType()->getNumResults() +
1294-
origFn->getLoweredFunctionType()->getNumIndirectMutatingParameters();
1295-
auto *resultIndices = IndexSubset::getDefault(
1296-
getASTContext(), numResults, /*includeAll*/ true);
1289+
auto *resultIndices =
1290+
autodiff::getAllFunctionSemanticResultIndices(origAFD);
12971291
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
12981292
witnessGenSig);
12991293
emitDifferentiabilityWitness(origAFD, origFn,

lib/SILGen/SILGenThunk.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,11 +547,13 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk(
547547
SILGenFunctionBuilder builder(*this);
548548
auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction();
549549
Mangle::ASTMangler mangler;
550+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
551+
originalFnDeclRef.getAbstractFunctionDecl());
550552
auto name = mangler.mangleAutoDiffDerivativeFunction(
551553
originalFnDeclRef.getAbstractFunctionDecl(),
552554
derivativeId->getKind(),
553555
AutoDiffConfig(derivativeId->getParameterIndices(),
554-
IndexSubset::get(getASTContext(), 1, {0}),
556+
resultIndices,
555557
derivativeId->getDerivativeGenericSignature()),
556558
/*isVTableThunk*/ true);
557559
auto *thunk = builder.getOrCreateFunction(
@@ -571,7 +573,8 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk(
571573
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
572574
derivativeId->getParameterIndices(),
573575
derivativeFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
574-
auto *loweredResultIndices = IndexSubset::get(getASTContext(), 1, {0});
576+
auto *loweredResultIndices = autodiff::getAllFunctionSemanticResultIndices(
577+
originalFnDeclRef.getAbstractFunctionDecl());
575578
auto diffFn = SGF.B.createDifferentiableFunction(
576579
loc, loweredParamIndices, loweredResultIndices, originalFn);
577580
auto derivativeFn = SGF.B.createDifferentiableFunctionExtract(

lib/TBDGen/TBDGen.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -730,22 +730,27 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
730730

731731
// Add derivative function symbols.
732732
for (const auto *differentiableAttr :
733-
AFD->getAttrs().getAttributes<DifferentiableAttr>())
733+
AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
734+
auto *resultIndices =
735+
autodiff::getAllFunctionSemanticResultIndices(AFD);
734736
addDerivativeConfiguration(
735737
differentiableAttr->getDifferentiabilityKind(),
736738
AFD,
737739
AutoDiffConfig(differentiableAttr->getParameterIndices(),
738-
IndexSubset::get(AFD->getASTContext(), 1, {0}),
740+
resultIndices,
739741
differentiableAttr->getDerivativeGenericSignature()));
742+
}
740743
for (const auto *derivativeAttr :
741-
AFD->getAttrs().getAttributes<DerivativeAttr>())
744+
AFD->getAttrs().getAttributes<DerivativeAttr>()) {
745+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
746+
derivativeAttr->getOriginalFunction(AFD->getASTContext()));
742747
addDerivativeConfiguration(
743748
DifferentiabilityKind::Reverse,
744749
derivativeAttr->getOriginalFunction(AFD->getASTContext()),
745750
AutoDiffConfig(derivativeAttr->getParameterIndices(),
746-
IndexSubset::get(AFD->getASTContext(), 1, {0}),
751+
resultIndices,
747752
AFD->getGenericSignature()));
748-
753+
}
749754
visitDefaultArguments(AFD, AFD->getParameters());
750755

751756
if (AFD->hasAsync()) {

0 commit comments

Comments
 (0)