Skip to content

Commit 80c2e10

Browse files
committed
Consolidated repeated result index generation into a central function.
1 parent 79e9b2e commit 80c2e10

File tree

7 files changed

+31
-35
lines changed

7 files changed

+31
-35
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
namespace swift {
3434

35+
class AbstractFunctionDecl;
3536
class AnyFunctionType;
3637
class SourceFile;
3738
class SILFunctionType;
@@ -575,6 +576,10 @@ void getFunctionSemanticResultTypes(
575576
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
576577
GenericEnvironment *genericEnv = nullptr);
577578

579+
/// Returns the indices of all semantic results for a given function.
580+
IndexSubset *getAllFunctionSemanticResultIndices(
581+
const AbstractFunctionDecl *AFD);
582+
578583
/// Returns the lowered SIL parameter indices for the given AST parameter
579584
/// indices and `AnyfunctionType`.
580585
///

lib/AST/AutoDiff.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ void autodiff::getFunctionSemanticResultTypes(
211211
}
212212
}
213213

214+
IndexSubset *
215+
autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) {
216+
auto originalFn = AFD->getInterfaceType()->castTo<AnyFunctionType>();
217+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
218+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
219+
auto numResults = semanticResults.size();
220+
return IndexSubset::getDefault(
221+
AFD->getASTContext(), numResults, /*includeAll*/ true);
222+
}
223+
214224
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
215225
IndexSubset *
216226
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,

lib/IRGen/IRGenMangler.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ class IRGenMangler : public Mangle::ASTMangler {
5757
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
5858
beginManglingWithAutoDiffOriginalFunction(func);
5959
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
60+
auto *resultIndices =
61+
autodiff::getAllFunctionSemanticResultIndices(func);
6062
AutoDiffConfig config(
6163
derivativeId->getParameterIndices(),
62-
IndexSubset::get(func->getASTContext(), 1, {0}),
64+
resultIndices,
6365
derivativeId->getDerivativeGenericSignature());
6466
appendAutoDiffFunctionParts("TJ", kind, config);
6567
appendOperator("Tj");
@@ -86,9 +88,11 @@ class IRGenMangler : public Mangle::ASTMangler {
8688
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
8789
beginManglingWithAutoDiffOriginalFunction(func);
8890
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
91+
auto *resultIndices =
92+
autodiff::getAllFunctionSemanticResultIndices(func);
8993
AutoDiffConfig config(
9094
derivativeId->getParameterIndices(),
91-
IndexSubset::get(func->getASTContext(), 1, {0}),
95+
resultIndices,
9296
derivativeId->getDerivativeGenericSignature());
9397
appendAutoDiffFunctionParts("TJ", kind, config);
9498
appendOperator("Tq");

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -852,13 +852,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
852852
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
853853
derivativeFunctionIdentifier->getParameterIndices(),
854854
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
855-
auto originalFn =
856-
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
857-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
858-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
859-
auto numResults = semanticResults.size();
860-
auto *resultIndices = IndexSubset::getDefault(
861-
getDecl()->getASTContext(), numResults, /*includeAll*/ true);
855+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
856+
asAutoDiffOriginalFunction().getAbstractFunctionDecl());
862857
AutoDiffConfig silConfig(
863858
silParameterIndices, resultIndices,
864859
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/Sema/TypeCheckAttr.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5026,12 +5026,8 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
50265026
}
50275027
getterDecl->getAttrs().add(newAttr);
50285028
// Register derivative function configuration.
5029-
auto originalFn = getterDecl->getInterfaceType()->castTo<AnyFunctionType>();
5030-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5031-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5032-
auto numResults = semanticResults.size();
5033-
auto *resultIndices = IndexSubset::getDefault(
5034-
ctx, numResults, /*includeAll*/ true);
5029+
auto *resultIndices =
5030+
autodiff::getAllFunctionSemanticResultIndices(getterDecl);
50355031
getterDecl->addDerivativeFunctionConfiguration(
50365032
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
50375033
return resolvedDiffParamIndices;
@@ -5472,12 +5468,8 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
54725468
}
54735469

54745470
// Register derivative function configuration.
5475-
auto originalFn = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
5476-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5477-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5478-
auto numResults = semanticResults.size();
5479-
auto *resultIndices = IndexSubset::getDefault(
5480-
Ctx, numResults, /*includeAll*/ true);
5471+
auto *resultIndices =
5472+
autodiff::getAllFunctionSemanticResultIndices(originalAFD);
54815473
originalAFD->addDerivativeFunctionConfiguration(
54825474
{resolvedDiffParamIndices, resultIndices,
54835475
derivative->getGenericSignature()});

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -491,13 +491,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
491491
witness->getAttrs().add(newAttr);
492492
success = true;
493493
// Register derivative function configuration.
494-
auto originalFn =
495-
witnessAFD->getInterfaceType()->castTo<AnyFunctionType>();
496-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
497-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
498-
auto numResults = semanticResults.size();
499-
auto *resultIndices = IndexSubset::getDefault(
500-
ctx, numResults, /*includeAll*/ true);
494+
auto *resultIndices =
495+
autodiff::getAllFunctionSemanticResultIndices(witnessAFD);
501496
witnessAFD->addDerivativeFunctionConfiguration(
502497
{newAttr->getParameterIndices(), resultIndices,
503498
newAttr->getDerivativeGenericSignature()});

lib/Serialization/ModuleFile.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -668,13 +668,8 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
668668
auto derivativeGenSig = derivativeGenSigOrError.get();
669669
// NOTE(TF-1038): Result indices are currently unsupported in derivative
670670
// registration attributes. In the meantime, always use all results.
671-
auto originalFn =
672-
originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
673-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
674-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
675-
auto numResults = semanticResults.size();
676-
auto *resultIndices = IndexSubset::getDefault(
677-
ctx, numResults, /*includeAll*/ true);
671+
auto *resultIndices =
672+
autodiff::getAllFunctionSemanticResultIndices(originalAFD);
678673
results.insert({parameterIndices, resultIndices, derivativeGenSig});
679674
}
680675
}

0 commit comments

Comments
 (0)