Skip to content

Commit 72866b6

Browse files
committed
Converting a last few areas to use multiple result indices.
1 parent bb743f8 commit 72866b6

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,11 +1145,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
11451145
diffAttr->getDerivativeGenericSignature()) &&
11461146
"Type-checking should resolve derivative generic signatures for "
11471147
"all original SIL functions with generic signatures");
1148-
auto numResults =
1149-
F->getLoweredFunctionType()->getNumResults() +
1150-
F->getLoweredFunctionType()->getNumIndirectMutatingParameters();
1151-
auto *resultIndices = IndexSubset::getDefault(
1152-
getASTContext(), numResults, /*includeAll*/ true);
1148+
auto *resultIndices =
1149+
autodiff::getAllFunctionSemanticResultIndices(AFD);
11531150
auto witnessGenSig =
11541151
autodiff::getDifferentiabilityWitnessGenericSignature(
11551152
AFD->getGenericSignature(),
@@ -1178,11 +1175,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
11781175
auto witnessGenSig =
11791176
autodiff::getDifferentiabilityWitnessGenericSignature(
11801177
origAFD->getGenericSignature(), AFD->getGenericSignature());
1181-
auto numResults =
1182-
origFn->getLoweredFunctionType()->getNumResults() +
1183-
origFn->getLoweredFunctionType()->getNumIndirectMutatingParameters();
1184-
auto *resultIndices = IndexSubset::getDefault(
1185-
getASTContext(), numResults, /*includeAll*/ true);
1178+
auto *resultIndices =
1179+
autodiff::getAllFunctionSemanticResultIndices(origAFD);
11861180
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
11871181
witnessGenSig);
11881182
emitDifferentiabilityWitness(origAFD, origFn,

lib/SILGen/SILGenThunk.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,11 +490,13 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk(
490490
SILGenFunctionBuilder builder(*this);
491491
auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction();
492492
Mangle::ASTMangler mangler;
493+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
494+
originalFnDeclRef.getAbstractFunctionDecl());
493495
auto name = mangler.mangleAutoDiffDerivativeFunction(
494496
originalFnDeclRef.getAbstractFunctionDecl(),
495497
derivativeId->getKind(),
496498
AutoDiffConfig(derivativeId->getParameterIndices(),
497-
IndexSubset::get(getASTContext(), 1, {0}),
499+
resultIndices,
498500
derivativeId->getDerivativeGenericSignature()),
499501
/*isVTableThunk*/ true);
500502
auto *thunk = builder.getOrCreateFunction(
@@ -515,7 +517,8 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk(
515517
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
516518
derivativeId->getParameterIndices(),
517519
derivativeFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
518-
auto *loweredResultIndices = IndexSubset::get(getASTContext(), 1, {0});
520+
auto *loweredResultIndices = autodiff::getAllFunctionSemanticResultIndices(
521+
originalFnDeclRef.getAbstractFunctionDecl());
519522
auto diffFn = SGF.B.createDifferentiableFunction(
520523
loc, loweredParamIndices, loweredResultIndices, originalFn);
521524
auto derivativeFn = SGF.B.createDifferentiableFunctionExtract(

lib/TBDGen/TBDGen.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -723,22 +723,27 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
723723

724724
// Add derivative function symbols.
725725
for (const auto *differentiableAttr :
726-
AFD->getAttrs().getAttributes<DifferentiableAttr>())
726+
AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
727+
auto *resultIndices =
728+
autodiff::getAllFunctionSemanticResultIndices(AFD);
727729
addDerivativeConfiguration(
728730
differentiableAttr->getDifferentiabilityKind(),
729731
AFD,
730732
AutoDiffConfig(differentiableAttr->getParameterIndices(),
731-
IndexSubset::get(AFD->getASTContext(), 1, {0}),
733+
resultIndices,
732734
differentiableAttr->getDerivativeGenericSignature()));
735+
}
733736
for (const auto *derivativeAttr :
734-
AFD->getAttrs().getAttributes<DerivativeAttr>())
737+
AFD->getAttrs().getAttributes<DerivativeAttr>()) {
738+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
739+
derivativeAttr->getOriginalFunction(AFD->getASTContext()));
735740
addDerivativeConfiguration(
736741
DifferentiabilityKind::Reverse,
737742
derivativeAttr->getOriginalFunction(AFD->getASTContext()),
738743
AutoDiffConfig(derivativeAttr->getParameterIndices(),
739-
IndexSubset::get(AFD->getASTContext(), 1, {0}),
744+
resultIndices,
740745
AFD->getGenericSignature()));
741-
746+
}
742747
visitDefaultArguments(AFD, AFD->getParameters());
743748

744749
if (AFD->hasAsync()) {

0 commit comments

Comments
 (0)