Skip to content

Commit ff8dc58

Browse files
committed
Consolidated repeated result index generation into a central function.
1 parent f5d4885 commit ff8dc58

File tree

7 files changed

+31
-35
lines changed

7 files changed

+31
-35
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
namespace swift {
3434

35+
class AbstractFunctionDecl;
3536
class AnyFunctionType;
3637
class SourceFile;
3738
class SILFunctionType;
@@ -575,6 +576,10 @@ void getFunctionSemanticResultTypes(
575576
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
576577
GenericEnvironment *genericEnv = nullptr);
577578

579+
/// Returns the indices of all semantic results for a given function.
580+
IndexSubset *getAllFunctionSemanticResultIndices(
581+
const AbstractFunctionDecl *AFD);
582+
578583
/// Returns the lowered SIL parameter indices for the given AST parameter
579584
/// indices and `AnyfunctionType`.
580585
///

lib/AST/AutoDiff.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ void autodiff::getFunctionSemanticResultTypes(
211211
}
212212
}
213213

214+
IndexSubset *
215+
autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) {
216+
auto originalFn = AFD->getInterfaceType()->castTo<AnyFunctionType>();
217+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
218+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
219+
auto numResults = semanticResults.size();
220+
return IndexSubset::getDefault(
221+
AFD->getASTContext(), numResults, /*includeAll*/ true);
222+
}
223+
214224
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
215225
IndexSubset *
216226
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,

lib/IRGen/IRGenMangler.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ class IRGenMangler : public Mangle::ASTMangler {
5757
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
5858
beginManglingWithAutoDiffOriginalFunction(func);
5959
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
60+
auto *resultIndices =
61+
autodiff::getAllFunctionSemanticResultIndices(func);
6062
AutoDiffConfig config(
6163
derivativeId->getParameterIndices(),
62-
IndexSubset::get(func->getASTContext(), 1, {0}),
64+
resultIndices,
6365
derivativeId->getDerivativeGenericSignature());
6466
appendAutoDiffFunctionParts("TJ", kind, config);
6567
appendOperator("Tj");
@@ -86,9 +88,11 @@ class IRGenMangler : public Mangle::ASTMangler {
8688
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
8789
beginManglingWithAutoDiffOriginalFunction(func);
8890
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
91+
auto *resultIndices =
92+
autodiff::getAllFunctionSemanticResultIndices(func);
8993
AutoDiffConfig config(
9094
derivativeId->getParameterIndices(),
91-
IndexSubset::get(func->getASTContext(), 1, {0}),
95+
resultIndices,
9296
derivativeId->getDerivativeGenericSignature());
9397
appendAutoDiffFunctionParts("TJ", kind, config);
9498
appendOperator("Tq");

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -853,13 +853,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
853853
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
854854
derivativeFunctionIdentifier->getParameterIndices(),
855855
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
856-
auto originalFn =
857-
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
858-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
859-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
860-
auto numResults = semanticResults.size();
861-
auto *resultIndices = IndexSubset::getDefault(
862-
getDecl()->getASTContext(), numResults, /*includeAll*/ true);
856+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
857+
asAutoDiffOriginalFunction().getAbstractFunctionDecl());
863858
AutoDiffConfig silConfig(
864859
silParameterIndices, resultIndices,
865860
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/Sema/TypeCheckAttr.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5073,12 +5073,8 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
50735073
}
50745074
getterDecl->getAttrs().add(newAttr);
50755075
// Register derivative function configuration.
5076-
auto originalFn = getterDecl->getInterfaceType()->castTo<AnyFunctionType>();
5077-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5078-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5079-
auto numResults = semanticResults.size();
5080-
auto *resultIndices = IndexSubset::getDefault(
5081-
ctx, numResults, /*includeAll*/ true);
5076+
auto *resultIndices =
5077+
autodiff::getAllFunctionSemanticResultIndices(getterDecl);
50825078
getterDecl->addDerivativeFunctionConfiguration(
50835079
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
50845080
return resolvedDiffParamIndices;
@@ -5519,12 +5515,8 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
55195515
}
55205516

55215517
// Register derivative function configuration.
5522-
auto originalFn = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
5523-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5524-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5525-
auto numResults = semanticResults.size();
5526-
auto *resultIndices = IndexSubset::getDefault(
5527-
Ctx, numResults, /*includeAll*/ true);
5518+
auto *resultIndices =
5519+
autodiff::getAllFunctionSemanticResultIndices(originalAFD);
55285520
originalAFD->addDerivativeFunctionConfiguration(
55295521
{resolvedDiffParamIndices, resultIndices,
55305522
derivative->getGenericSignature()});

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -490,13 +490,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
490490
witness->getAttrs().add(newAttr);
491491
success = true;
492492
// Register derivative function configuration.
493-
auto originalFn =
494-
witnessAFD->getInterfaceType()->castTo<AnyFunctionType>();
495-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
496-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
497-
auto numResults = semanticResults.size();
498-
auto *resultIndices = IndexSubset::getDefault(
499-
ctx, numResults, /*includeAll*/ true);
493+
auto *resultIndices =
494+
autodiff::getAllFunctionSemanticResultIndices(witnessAFD);
500495
witnessAFD->addDerivativeFunctionConfiguration(
501496
{newAttr->getParameterIndices(), resultIndices,
502497
newAttr->getDerivativeGenericSignature()});

lib/Serialization/ModuleFile.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -668,13 +668,8 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
668668
auto derivativeGenSig = derivativeGenSigOrError.get();
669669
// NOTE(TF-1038): Result indices are currently unsupported in derivative
670670
// registration attributes. In the meantime, always use all results.
671-
auto originalFn =
672-
originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
673-
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
674-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
675-
auto numResults = semanticResults.size();
676-
auto *resultIndices = IndexSubset::getDefault(
677-
ctx, numResults, /*includeAll*/ true);
671+
auto *resultIndices =
672+
autodiff::getAllFunctionSemanticResultIndices(originalAFD);
678673
results.insert({parameterIndices, resultIndices, derivativeGenSig});
679674
}
680675
}

0 commit comments

Comments
 (0)