Skip to content

Commit f5d4885

Browse files
committed
Reworking logic for non-wrt inout parameters. Replacing single result index with all result indices in multiple places.
1 parent 2420dbc commit f5d4885

File tree

5 files changed

+86
-21
lines changed

5 files changed

+86
-21
lines changed

lib/AST/Type.cpp

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6444,11 +6444,14 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
64446444
bool hasInoutResult = false;
64456445
for (auto i : range(originalResults.size())) {
64466446
auto originalResult = originalResults[i];
6447+
auto originalResultType = originalResult.type;
6448+
// Voids currently have a defined tangent vector, so ignore them.
6449+
if (originalResultType->isVoid())
6450+
continue;
64476451
if (originalResult.isInout) {
64486452
hasInoutResult = true;
64496453
continue;
64506454
}
6451-
auto originalResultType = originalResult.type;
64526455
// Get the original semantic result type's `TangentVector` associated type.
64536456
auto resultTan =
64546457
originalResultType->getAutoDiffTangentSpace(lookupConformance);
@@ -6458,19 +6461,48 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
64586461
resultTanTypes.push_back(resultTanType);
64596462
}
64606463
// Append non-wrt inout result tangent spaces.
6461-
auto *resultFunctionType = this->getResult()->getAs<AnyFunctionType>();
6462-
auto sourceFunction = resultFunctionType ? resultFunctionType : this;
6463-
for (unsigned i : range(sourceFunction->getNumParams())) {
6464-
auto param = sourceFunction->getParams()[i];
6465-
if (parameterIndices->contains(i))
6466-
continue;
6467-
if (param.isInOut()) {
6468-
auto resultType = param.getPlainType();
6469-
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
6470-
if (!resultTan)
6464+
// This uses the logic from getSubsetParameters(), only operating over all
6465+
// parameter indices and looking for non-wrt indices.
6466+
SmallVector<AnyFunctionType *, 2> curryLevels;
6467+
// An inlined version of unwrapCurryLevels().
6468+
AnyFunctionType *fnTy = this;
6469+
while (fnTy != nullptr) {
6470+
curryLevels.push_back(fnTy);
6471+
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
6472+
}
6473+
6474+
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
6475+
unsigned currentOffset = 0;
6476+
for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
6477+
curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
6478+
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
6479+
}
6480+
6481+
if (!makeSelfParamFirst) {
6482+
std::reverse(curryLevels.begin(), curryLevels.end());
6483+
std::reverse(curryLevelParameterIndexOffsets.begin(),
6484+
curryLevelParameterIndexOffsets.end());
6485+
}
6486+
6487+
for (unsigned curryLevelIndex : indices(curryLevels)) {
6488+
auto *curryLevel = curryLevels[curryLevelIndex];
6489+
unsigned parameterIndexOffset =
6490+
curryLevelParameterIndexOffsets[curryLevelIndex];
6491+
for (unsigned paramIndex : range(curryLevel->getNumParams())) {
6492+
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
64716493
continue;
6472-
auto resultTanType = resultTan->getType();
6473-
resultTanTypes.push_back(resultTanType);
6494+
6495+
auto param = curryLevel->getParams()[paramIndex];
6496+
if (param.isInOut()) {
6497+
auto resultType = param.getPlainType();
6498+
if (resultType->isVoid())
6499+
continue;
6500+
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
6501+
if (!resultTan)
6502+
continue;
6503+
auto resultTanType = resultTan->getType();
6504+
resultTanTypes.push_back(resultTanType);
6505+
}
64746506
}
64756507
}
64766508

@@ -6565,6 +6597,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
65656597
std::make_pair(paramType, i));
65666598
}
65676599
if (diffParam.isInOut()) {
6600+
if (paramType->isVoid())
6601+
continue;
65686602
inoutParams.push_back(diffParam);
65696603
continue;
65706604
}

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,13 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
853853
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
854854
derivativeFunctionIdentifier->getParameterIndices(),
855855
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
856-
auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0});
856+
auto originalFn =
857+
getDecl()->getInterfaceType()->castTo<AnyFunctionType>();
858+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
859+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
860+
auto numResults = semanticResults.size();
861+
auto *resultIndices = IndexSubset::getDefault(
862+
getDecl()->getASTContext(), numResults, /*includeAll*/ true);
857863
AutoDiffConfig silConfig(
858864
silParameterIndices, resultIndices,
859865
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/Sema/TypeCheckAttr.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5073,7 +5073,12 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
50735073
}
50745074
getterDecl->getAttrs().add(newAttr);
50755075
// Register derivative function configuration.
5076-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
5076+
auto originalFn = getterDecl->getInterfaceType()->castTo<AnyFunctionType>();
5077+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5078+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5079+
auto numResults = semanticResults.size();
5080+
auto *resultIndices = IndexSubset::getDefault(
5081+
ctx, numResults, /*includeAll*/ true);
50775082
getterDecl->addDerivativeFunctionConfiguration(
50785083
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
50795084
return resolvedDiffParamIndices;
@@ -5088,7 +5093,11 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
50885093
return nullptr;
50895094
}
50905095
// Register derivative function configuration.
5091-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
5096+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5097+
autodiff::getFunctionSemanticResultTypes(originalFnRemappedTy, semanticResults);
5098+
auto numResults = semanticResults.size();
5099+
auto *resultIndices = IndexSubset::getDefault(
5100+
ctx, numResults, /*includeAll*/ true);
50925101
original->addDerivativeFunctionConfiguration(
50935102
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
50945103
return resolvedDiffParamIndices;
@@ -5510,7 +5519,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
55105519
}
55115520

55125521
// Register derivative function configuration.
5513-
auto *resultIndices = IndexSubset::get(Ctx, 1, {0});
5522+
auto originalFn = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
5523+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
5524+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
5525+
auto numResults = semanticResults.size();
5526+
auto *resultIndices = IndexSubset::getDefault(
5527+
Ctx, numResults, /*includeAll*/ true);
55145528
originalAFD->addDerivativeFunctionConfiguration(
55155529
{resolvedDiffParamIndices, resultIndices,
55165530
derivative->getGenericSignature()});

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,13 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
490490
witness->getAttrs().add(newAttr);
491491
success = true;
492492
// Register derivative function configuration.
493-
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
493+
auto originalFn =
494+
witnessAFD->getInterfaceType()->castTo<AnyFunctionType>();
495+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
496+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
497+
auto numResults = semanticResults.size();
498+
auto *resultIndices = IndexSubset::getDefault(
499+
ctx, numResults, /*includeAll*/ true);
494500
witnessAFD->addDerivativeFunctionConfiguration(
495501
{newAttr->getParameterIndices(), resultIndices,
496502
newAttr->getDerivativeGenericSignature()});

lib/Serialization/ModuleFile.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,14 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
667667
}
668668
auto derivativeGenSig = derivativeGenSigOrError.get();
669669
// NOTE(TF-1038): Result indices are currently unsupported in derivative
670-
// registration attributes. In the meantime, always use `{0}` (wrt the
671-
// first and only result).
672-
auto resultIndices = IndexSubset::get(ctx, 1, {0});
670+
// registration attributes. In the meantime, always use all results.
671+
auto originalFn =
672+
originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
673+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
674+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
675+
auto numResults = semanticResults.size();
676+
auto *resultIndices = IndexSubset::getDefault(
677+
ctx, numResults, /*includeAll*/ true);
673678
results.insert({parameterIndices, resultIndices, derivativeGenSig});
674679
}
675680
}

0 commit comments

Comments
 (0)