Skip to content

Commit da7a037

Browse files
committed
Reworking logic for non-wrt inout parameters. Replacing single result index with all result indices in multiple places.
1 parent 36be38e commit da7a037

File tree

5 files changed

+86
-21
lines changed

5 files changed

+86
-21
lines changed

lib/AST/Type.cpp

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5482,11 +5482,14 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
54825482
bool hasInoutResult = false;
54835483
for (auto i : range(originalResults.size())) {
54845484
auto originalResult = originalResults[i];
5485+
auto originalResultType = originalResult.type;
5486+
// Voids currently have a defined tangent vector, so ignore them.
5487+
if (originalResultType->isVoid())
5488+
continue;
54855489
if (originalResult.isInout) {
54865490
hasInoutResult = true;
54875491
continue;
54885492
}
5489-
auto originalResultType = originalResult.type;
54905493
// Get the original semantic result type's `TangentVector` associated type.
54915494
auto resultTan =
54925495
originalResultType->getAutoDiffTangentSpace(lookupConformance);
@@ -5496,19 +5499,48 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
54965499
resultTanTypes.push_back(resultTanType);
54975500
}
54985501
// Append non-wrt inout result tangent spaces.
5499-
auto *resultFunctionType = this->getResult()->getAs<AnyFunctionType>();
5500-
auto sourceFunction = resultFunctionType ? resultFunctionType : this;
5501-
for (unsigned i : range(sourceFunction->getNumParams())) {
5502-
auto param = sourceFunction->getParams()[i];
5503-
if (parameterIndices->contains(i))
5504-
continue;
5505-
if (param.isInOut()) {
5506-
auto resultType = param.getPlainType();
5507-
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
5508-
if (!resultTan)
5502+
// This uses the logic from getSubsetParameters(), only operating over all
5503+
// parameter indices and looking for non-wrt indices.
5504+
SmallVector<AnyFunctionType *, 2> curryLevels;
5505+
// An inlined version of unwrapCurryLevels().
5506+
AnyFunctionType *fnTy = this;
5507+
while (fnTy != nullptr) {
5508+
curryLevels.push_back(fnTy);
5509+
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
5510+
}
5511+
5512+
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
5513+
unsigned currentOffset = 0;
5514+
for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
5515+
curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
5516+
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
5517+
}
5518+
5519+
if (!makeSelfParamFirst) {
5520+
std::reverse(curryLevels.begin(), curryLevels.end());
5521+
std::reverse(curryLevelParameterIndexOffsets.begin(),
5522+
curryLevelParameterIndexOffsets.end());
5523+
}
5524+
5525+
for (unsigned curryLevelIndex : indices(curryLevels)) {
5526+
auto *curryLevel = curryLevels[curryLevelIndex];
5527+
unsigned parameterIndexOffset =
5528+
curryLevelParameterIndexOffsets[curryLevelIndex];
5529+
for (unsigned paramIndex : range(curryLevel->getNumParams())) {
5530+
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
55095531
continue;
5510-
auto resultTanType = resultTan->getType();
5511-
resultTanTypes.push_back(resultTanType);
5532+
5533+
auto param = curryLevel->getParams()[paramIndex];
5534+
if (param.isInOut()) {
5535+
auto resultType = param.getPlainType();
5536+
if (resultType->isVoid())
5537+
continue;
5538+
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
5539+
if (!resultTan)
5540+
continue;
5541+
auto resultTanType = resultTan->getType();
5542+
resultTanTypes.push_back(resultTanType);
5543+
}
55125544
}
55135545
}
55145546

@@ -5603,6 +5635,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56035635
std::make_pair(paramType, i));
56045636
}
56055637
if (diffParam.isInOut()) {
5638+
if (paramType->isVoid())
5639+
continue;
56065640
inoutParams.push_back(diffParam);
56075641
continue;
56085642
}

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,13 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
781781
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
782782
derivativeFunctionIdentifier->getParameterIndices(),
783783
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
784-
auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0});
784+
auto originalFn =
785+
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
786+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
787+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
788+
auto numResults = semanticResults.size();
789+
auto *resultIndices = IndexSubset::getDefault(
790+
getDecl()->getASTContext(), numResults, /*includeAll*/ true);
785791
AutoDiffConfig silConfig(
786792
silParameterIndices, resultIndices,
787793
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/Sema/TypeCheckAttr.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4551,7 +4551,12 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
45514551
}
45524552
getterDecl->getAttrs().add(newAttr);
45534553
// Register derivative function configuration.
4554-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
4554+
auto originalFn = getterDecl->getInterfaceType()->castTo<AnyFunctionType>();
4555+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
4556+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
4557+
auto numResults = semanticResults.size();
4558+
auto *resultIndices = IndexSubset::getDefault(
4559+
ctx, numResults, /*includeAll*/ true);
45554560
getterDecl->addDerivativeFunctionConfiguration(
45564561
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
45574562
return resolvedDiffParamIndices;
@@ -4566,7 +4571,11 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
45664571
return nullptr;
45674572
}
45684573
// Register derivative function configuration.
4569-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
4574+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
4575+
autodiff::getFunctionSemanticResultTypes(originalFnRemappedTy, semanticResults);
4576+
auto numResults = semanticResults.size();
4577+
auto *resultIndices = IndexSubset::getDefault(
4578+
ctx, numResults, /*includeAll*/ true);
45704579
original->addDerivativeFunctionConfiguration(
45714580
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
45724581
return resolvedDiffParamIndices;
@@ -5005,7 +5014,12 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
50055014
}
50065015

50075016
// Register derivative function configuration.
5008-
auto *resultIndices = IndexSubset::get(Ctx, 1, {0});
5017+
auto originalFn = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
5018+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5019+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5020+
auto numResults = semanticResults.size();
5021+
auto *resultIndices = IndexSubset::getDefault(
5022+
Ctx, numResults, /*includeAll*/ true);
50095023
originalAFD->addDerivativeFunctionConfiguration(
50105024
{resolvedDiffParamIndices, resultIndices,
50115025
derivative->getGenericSignature()});

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,13 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
457457
witness->getAttrs().add(newAttr);
458458
success = true;
459459
// Register derivative function configuration.
460-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
460+
auto originalFn =
461+
witnessAFD->getInterfaceType()->castTo<AnyFunctionType>();
462+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
463+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
464+
auto numResults = semanticResults.size();
465+
auto *resultIndices = IndexSubset::getDefault(
466+
ctx, numResults, /*includeAll*/ true);
461467
witnessAFD->addDerivativeFunctionConfiguration(
462468
{newAttr->getParameterIndices(), resultIndices,
463469
newAttr->getDerivativeGenericSignature()});

lib/Serialization/ModuleFile.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -654,9 +654,14 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
654654
}
655655
auto derivativeGenSig = derivativeGenSigOrError.get();
656656
// NOTE(TF-1038): Result indices are currently unsupported in derivative
657-
// registration attributes. In the meantime, always use `{0}` (wrt the
658-
// first and only result).
659-
auto resultIndices = IndexSubset::get(ctx, 1, {0});
657+
// registration attributes. In the meantime, always use all results.
658+
auto originalFn =
659+
originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
660+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
661+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
662+
auto numResults = semanticResults.size();
663+
auto *resultIndices = IndexSubset::getDefault(
664+
ctx, numResults, /*includeAll*/ true);
660665
results.insert({parameterIndices, resultIndices, derivativeGenSig});
661666
}
662667
}

0 commit comments

Comments
 (0)