Skip to content

Commit 79e9b2e

Browse files
committed
Reworking logic for non-wrt inout parameters. Replacing single result index with all result indices in multiple places.
1 parent 036a638 commit 79e9b2e

File tree

5 files changed

+86
-21
lines changed

5 files changed

+86
-21
lines changed

lib/AST/Type.cpp

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6376,11 +6376,14 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
63766376
bool hasInoutResult = false;
63776377
for (auto i : range(originalResults.size())) {
63786378
auto originalResult = originalResults[i];
6379+
auto originalResultType = originalResult.type;
6380+
// Voids currently have a defined tangent vector, so ignore them.
6381+
if (originalResultType->isVoid())
6382+
continue;
63796383
if (originalResult.isInout) {
63806384
hasInoutResult = true;
63816385
continue;
63826386
}
6383-
auto originalResultType = originalResult.type;
63846387
// Get the original semantic result type's `TangentVector` associated type.
63856388
auto resultTan =
63866389
originalResultType->getAutoDiffTangentSpace(lookupConformance);
@@ -6390,19 +6393,48 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
63906393
resultTanTypes.push_back(resultTanType);
63916394
}
63926395
// Append non-wrt inout result tangent spaces.
6393-
auto *resultFunctionType = this->getResult()->getAs<AnyFunctionType>();
6394-
auto sourceFunction = resultFunctionType ? resultFunctionType : this;
6395-
for (unsigned i : range(sourceFunction->getNumParams())) {
6396-
auto param = sourceFunction->getParams()[i];
6397-
if (parameterIndices->contains(i))
6398-
continue;
6399-
if (param.isInOut()) {
6400-
auto resultType = param.getPlainType();
6401-
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
6402-
if (!resultTan)
6396+
// This uses the logic from getSubsetParameters(), only operating over all
6397+
// parameter indices and looking for non-wrt indices.
6398+
SmallVector<AnyFunctionType *, 2> curryLevels;
6399+
// An inlined version of unwrapCurryLevels().
6400+
AnyFunctionType *fnTy = this;
6401+
while (fnTy != nullptr) {
6402+
curryLevels.push_back(fnTy);
6403+
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
6404+
}
6405+
6406+
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
6407+
unsigned currentOffset = 0;
6408+
for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
6409+
curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
6410+
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
6411+
}
6412+
6413+
if (!makeSelfParamFirst) {
6414+
std::reverse(curryLevels.begin(), curryLevels.end());
6415+
std::reverse(curryLevelParameterIndexOffsets.begin(),
6416+
curryLevelParameterIndexOffsets.end());
6417+
}
6418+
6419+
for (unsigned curryLevelIndex : indices(curryLevels)) {
6420+
auto *curryLevel = curryLevels[curryLevelIndex];
6421+
unsigned parameterIndexOffset =
6422+
curryLevelParameterIndexOffsets[curryLevelIndex];
6423+
for (unsigned paramIndex : range(curryLevel->getNumParams())) {
6424+
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
64036425
continue;
6404-
auto resultTanType = resultTan->getType();
6405-
resultTanTypes.push_back(resultTanType);
6426+
6427+
auto param = curryLevel->getParams()[paramIndex];
6428+
if (param.isInOut()) {
6429+
auto resultType = param.getPlainType();
6430+
if (resultType->isVoid())
6431+
continue;
6432+
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
6433+
if (!resultTan)
6434+
continue;
6435+
auto resultTanType = resultTan->getType();
6436+
resultTanTypes.push_back(resultTanType);
6437+
}
64066438
}
64076439
}
64086440

@@ -6497,6 +6529,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
64976529
std::make_pair(paramType, i));
64986530
}
64996531
if (diffParam.isInOut()) {
6532+
if (paramType->isVoid())
6533+
continue;
65006534
inoutParams.push_back(diffParam);
65016535
continue;
65026536
}

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,13 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
852852
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
853853
derivativeFunctionIdentifier->getParameterIndices(),
854854
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
855-
auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0});
855+
auto originalFn =
856+
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
857+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
858+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
859+
auto numResults = semanticResults.size();
860+
auto *resultIndices = IndexSubset::getDefault(
861+
getDecl()->getASTContext(), numResults, /*includeAll*/ true);
856862
AutoDiffConfig silConfig(
857863
silParameterIndices, resultIndices,
858864
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/Sema/TypeCheckAttr.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5026,7 +5026,12 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
50265026
}
50275027
getterDecl->getAttrs().add(newAttr);
50285028
// Register derivative function configuration.
5029-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
5029+
auto originalFn = getterDecl->getInterfaceType()->castTo<AnyFunctionType>();
5030+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5031+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5032+
auto numResults = semanticResults.size();
5033+
auto *resultIndices = IndexSubset::getDefault(
5034+
ctx, numResults, /*includeAll*/ true);
50305035
getterDecl->addDerivativeFunctionConfiguration(
50315036
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
50325037
return resolvedDiffParamIndices;
@@ -5041,7 +5046,11 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
50415046
return nullptr;
50425047
}
50435048
// Register derivative function configuration.
5044-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
5049+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5050+
autodiff::getFunctionSemanticResultTypes(originalFnRemappedTy, semanticResults);
5051+
auto numResults = semanticResults.size();
5052+
auto *resultIndices = IndexSubset::getDefault(
5053+
ctx, numResults, /*includeAll*/ true);
50455054
original->addDerivativeFunctionConfiguration(
50465055
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
50475056
return resolvedDiffParamIndices;
@@ -5463,7 +5472,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
54635472
}
54645473

54655474
// Register derivative function configuration.
5466-
auto *resultIndices = IndexSubset::get(Ctx, 1, {0});
5475+
auto originalFn = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
5476+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5477+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5478+
auto numResults = semanticResults.size();
5479+
auto *resultIndices = IndexSubset::getDefault(
5480+
Ctx, numResults, /*includeAll*/ true);
54675481
originalAFD->addDerivativeFunctionConfiguration(
54685482
{resolvedDiffParamIndices, resultIndices,
54695483
derivative->getGenericSignature()});

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,13 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
491491
witness->getAttrs().add(newAttr);
492492
success = true;
493493
// Register derivative function configuration.
494-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
494+
auto originalFn =
495+
witnessAFD->getInterfaceType()->castTo<AnyFunctionType>();
496+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
497+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
498+
auto numResults = semanticResults.size();
499+
auto *resultIndices = IndexSubset::getDefault(
500+
ctx, numResults, /*includeAll*/ true);
495501
witnessAFD->addDerivativeFunctionConfiguration(
496502
{newAttr->getParameterIndices(), resultIndices,
497503
newAttr->getDerivativeGenericSignature()});

lib/Serialization/ModuleFile.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,14 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
667667
}
668668
auto derivativeGenSig = derivativeGenSigOrError.get();
669669
// NOTE(TF-1038): Result indices are currently unsupported in derivative
670-
// registration attributes. In the meantime, always use `{0}` (wrt the
671-
// first and only result).
672-
auto resultIndices = IndexSubset::get(ctx, 1, {0});
670+
// registration attributes. In the meantime, always use all results.
671+
auto originalFn =
672+
originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
673+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
674+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
675+
auto numResults = semanticResults.size();
676+
auto *resultIndices = IndexSubset::getDefault(
677+
ctx, numResults, /*includeAll*/ true);
673678
results.insert({parameterIndices, resultIndices, derivativeGenSig});
674679
}
675680
}

0 commit comments

Comments
 (0)