Skip to content

Commit bb743f8

Browse files
committed
Consolidated repeated result index generation into a central function.
1 parent da7a037 commit bb743f8

File tree

7 files changed

+31
-35
lines changed

7 files changed

+31
-35
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
namespace swift {
3434

35+
class AbstractFunctionDecl;
3536
class AnyFunctionType;
3637
class SourceFile;
3738
class SILFunctionType;
@@ -575,6 +576,10 @@ void getFunctionSemanticResultTypes(
575576
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
576577
GenericEnvironment *genericEnv = nullptr);
577578

579+
/// Returns the indices of all semantic results for a given function.
580+
IndexSubset *getAllFunctionSemanticResultIndices(
581+
const AbstractFunctionDecl *AFD);
582+
578583
/// Returns the lowered SIL parameter indices for the given AST parameter
579584
/// indices and `AnyfunctionType`.
580585
///

lib/AST/AutoDiff.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ void autodiff::getFunctionSemanticResultTypes(
211211
}
212212
}
213213

214+
IndexSubset *
215+
autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) {
216+
auto originalFn = AFD->getInterfaceType()->castTo<AnyFunctionType>();
217+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
218+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
219+
auto numResults = semanticResults.size();
220+
return IndexSubset::getDefault(
221+
AFD->getASTContext(), numResults, /*includeAll*/ true);
222+
}
223+
214224
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
215225
IndexSubset *
216226
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,

lib/IRGen/IRGenMangler.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ class IRGenMangler : public Mangle::ASTMangler {
5757
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
5858
beginManglingWithAutoDiffOriginalFunction(func);
5959
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
60+
auto *resultIndices =
61+
autodiff::getAllFunctionSemanticResultIndices(func);
6062
AutoDiffConfig config(
6163
derivativeId->getParameterIndices(),
62-
IndexSubset::get(func->getASTContext(), 1, {0}),
64+
resultIndices,
6365
derivativeId->getDerivativeGenericSignature());
6466
appendAutoDiffFunctionParts("TJ", kind, config);
6567
appendOperator("Tj");
@@ -86,9 +88,11 @@ class IRGenMangler : public Mangle::ASTMangler {
8688
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
8789
beginManglingWithAutoDiffOriginalFunction(func);
8890
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
91+
auto *resultIndices =
92+
autodiff::getAllFunctionSemanticResultIndices(func);
8993
AutoDiffConfig config(
9094
derivativeId->getParameterIndices(),
91-
IndexSubset::get(func->getASTContext(), 1, {0}),
95+
resultIndices,
9296
derivativeId->getDerivativeGenericSignature());
9397
appendAutoDiffFunctionParts("TJ", kind, config);
9498
appendOperator("Tq");

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -781,13 +781,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
781781
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
782782
derivativeFunctionIdentifier->getParameterIndices(),
783783
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
784-
auto originalFn =
785-
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
786-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
787-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
788-
auto numResults = semanticResults.size();
789-
auto *resultIndices = IndexSubset::getDefault(
790-
getDecl()->getASTContext(), numResults, /*includeAll*/ true);
784+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
785+
asAutoDiffOriginalFunction().getAbstractFunctionDecl());
791786
AutoDiffConfig silConfig(
792787
silParameterIndices, resultIndices,
793788
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/Sema/TypeCheckAttr.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4551,12 +4551,8 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
45514551
}
45524552
getterDecl->getAttrs().add(newAttr);
45534553
// Register derivative function configuration.
4554-
auto originalFn = getterDecl->getInterfaceType()->castTo<AnyFunctionType>();
4555-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
4556-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
4557-
auto numResults = semanticResults.size();
4558-
auto *resultIndices = IndexSubset::getDefault(
4559-
ctx, numResults, /*includeAll*/ true);
4554+
auto *resultIndices =
4555+
autodiff::getAllFunctionSemanticResultIndices(getterDecl);
45604556
getterDecl->addDerivativeFunctionConfiguration(
45614557
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
45624558
return resolvedDiffParamIndices;
@@ -5014,12 +5010,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
50145010
}
50155011

50165012
// Register derivative function configuration.
5017-
auto originalFn = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
5018-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5019-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5020-
auto numResults = semanticResults.size();
5021-
auto *resultIndices = IndexSubset::getDefault(
5022-
Ctx, numResults, /*includeAll*/ true);
5013+
auto *resultIndices =
5014+
autodiff::getAllFunctionSemanticResultIndices(originalAFD);
50235015
originalAFD->addDerivativeFunctionConfiguration(
50245016
{resolvedDiffParamIndices, resultIndices,
50255017
derivative->getGenericSignature()});

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
457457
witness->getAttrs().add(newAttr);
458458
success = true;
459459
// Register derivative function configuration.
460-
auto originalFn =
461-
witnessAFD->getInterfaceType()->castTo<AnyFunctionType>();
462-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
463-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
464-
auto numResults = semanticResults.size();
465-
auto *resultIndices = IndexSubset::getDefault(
466-
ctx, numResults, /*includeAll*/ true);
460+
auto *resultIndices =
461+
autodiff::getAllFunctionSemanticResultIndices(witnessAFD);
467462
witnessAFD->addDerivativeFunctionConfiguration(
468463
{newAttr->getParameterIndices(), resultIndices,
469464
newAttr->getDerivativeGenericSignature()});

lib/Serialization/ModuleFile.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -655,13 +655,8 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
655655
auto derivativeGenSig = derivativeGenSigOrError.get();
656656
// NOTE(TF-1038): Result indices are currently unsupported in derivative
657657
// registration attributes. In the meantime, always use all results.
658-
auto originalFn =
659-
originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
660-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
661-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
662-
auto numResults = semanticResults.size();
663-
auto *resultIndices = IndexSubset::getDefault(
664-
ctx, numResults, /*includeAll*/ true);
658+
auto *resultIndices =
659+
autodiff::getAllFunctionSemanticResultIndices(originalAFD);
665660
results.insert({parameterIndices, resultIndices, derivativeGenSig});
666661
}
667662
}

0 commit comments

Comments
 (0)