Skip to content

Commit 709155a

Browse files
committed
Converting a last few areas to use multiple result indices.
1 parent 80c2e10 commit 709155a

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,11 +1247,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
12471247
diffAttr->getDerivativeGenericSignature()) &&
12481248
"Type-checking should resolve derivative generic signatures for "
12491249
"all original SIL functions with generic signatures");
1250-
auto numResults =
1251-
F->getLoweredFunctionType()->getNumResults() +
1252-
F->getLoweredFunctionType()->getNumIndirectMutatingParameters();
1253-
auto *resultIndices = IndexSubset::getDefault(
1254-
getASTContext(), numResults, /*includeAll*/ true);
1250+
auto *resultIndices =
1251+
autodiff::getAllFunctionSemanticResultIndices(AFD);
12551252
auto witnessGenSig =
12561253
autodiff::getDifferentiabilityWitnessGenericSignature(
12571254
AFD->getGenericSignature(),
@@ -1280,11 +1277,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
12801277
auto witnessGenSig =
12811278
autodiff::getDifferentiabilityWitnessGenericSignature(
12821279
origAFD->getGenericSignature(), AFD->getGenericSignature());
1283-
auto numResults =
1284-
origFn->getLoweredFunctionType()->getNumResults() +
1285-
origFn->getLoweredFunctionType()->getNumIndirectMutatingParameters();
1286-
auto *resultIndices = IndexSubset::getDefault(
1287-
getASTContext(), numResults, /*includeAll*/ true);
1280+
auto *resultIndices =
1281+
autodiff::getAllFunctionSemanticResultIndices(origAFD);
12881282
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
12891283
witnessGenSig);
12901284
emitDifferentiabilityWitness(origAFD, origFn,

lib/SILGen/SILGenThunk.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,11 +547,13 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk(
547547
SILGenFunctionBuilder builder(*this);
548548
auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction();
549549
Mangle::ASTMangler mangler;
550+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
551+
originalFnDeclRef.getAbstractFunctionDecl());
550552
auto name = mangler.mangleAutoDiffDerivativeFunction(
551553
originalFnDeclRef.getAbstractFunctionDecl(),
552554
derivativeId->getKind(),
553555
AutoDiffConfig(derivativeId->getParameterIndices(),
554-
IndexSubset::get(getASTContext(), 1, {0}),
556+
resultIndices,
555557
derivativeId->getDerivativeGenericSignature()),
556558
/*isVTableThunk*/ true);
557559
auto *thunk = builder.getOrCreateFunction(
@@ -571,7 +573,8 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk(
571573
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
572574
derivativeId->getParameterIndices(),
573575
derivativeFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
574-
auto *loweredResultIndices = IndexSubset::get(getASTContext(), 1, {0});
576+
auto *loweredResultIndices = autodiff::getAllFunctionSemanticResultIndices(
577+
originalFnDeclRef.getAbstractFunctionDecl());
575578
auto diffFn = SGF.B.createDifferentiableFunction(
576579
loc, loweredParamIndices, loweredResultIndices, originalFn);
577580
auto derivativeFn = SGF.B.createDifferentiableFunctionExtract(

lib/TBDGen/TBDGen.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -726,22 +726,27 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
726726

727727
// Add derivative function symbols.
728728
for (const auto *differentiableAttr :
729-
AFD->getAttrs().getAttributes<DifferentiableAttr>())
729+
AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
730+
auto *resultIndices =
731+
autodiff::getAllFunctionSemanticResultIndices(AFD);
730732
addDerivativeConfiguration(
731733
differentiableAttr->getDifferentiabilityKind(),
732734
AFD,
733735
AutoDiffConfig(differentiableAttr->getParameterIndices(),
734-
IndexSubset::get(AFD->getASTContext(), 1, {0}),
736+
resultIndices,
735737
differentiableAttr->getDerivativeGenericSignature()));
738+
}
736739
for (const auto *derivativeAttr :
737-
AFD->getAttrs().getAttributes<DerivativeAttr>())
740+
AFD->getAttrs().getAttributes<DerivativeAttr>()) {
741+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
742+
derivativeAttr->getOriginalFunction(AFD->getASTContext()));
738743
addDerivativeConfiguration(
739744
DifferentiabilityKind::Reverse,
740745
derivativeAttr->getOriginalFunction(AFD->getASTContext()),
741746
AutoDiffConfig(derivativeAttr->getParameterIndices(),
742-
IndexSubset::get(AFD->getASTContext(), 1, {0}),
747+
resultIndices,
743748
AFD->getGenericSignature()));
744-
749+
}
745750
visitDefaultArguments(AFD, AFD->getParameters());
746751

747752
if (AFD->hasAsync()) {

0 commit comments

Comments
 (0)