Skip to content

Commit 0872802

Browse files
aslBradLarson
andcommitted
Add support for differentiable functions having multiple semantic results
Co-authored-by: Brad Larson <[email protected]>
1 parent effa462 commit 0872802

24 files changed

+391
-106
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
namespace swift {
3434

35+
class AbstractFunctionDecl;
3536
class AnyFunctionType;
3637
class SourceFile;
3738
class SILFunctionType;
@@ -398,9 +399,6 @@ class DerivativeFunctionTypeError
398399
enum class Kind {
399400
/// Original function type has no semantic results.
400401
NoSemanticResults,
401-
/// Original function type has multiple semantic results.
402-
// TODO(TF-1250): Support function types with multiple semantic results.
403-
MultipleSemanticResults,
404402
/// Differentiability parmeter indices are empty.
405403
NoDifferentiabilityParameters,
406404
/// A differentiability parameter does not conform to `Differentiable`.
@@ -429,7 +427,6 @@ class DerivativeFunctionTypeError
429427
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
430428
: functionType(functionType), kind(kind), value(Value()) {
431429
assert(kind == Kind::NoSemanticResults ||
432-
kind == Kind::MultipleSemanticResults ||
433430
kind == Kind::NoDifferentiabilityParameters);
434431
};
435432

@@ -579,6 +576,10 @@ void getFunctionSemanticResultTypes(
579576
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
580577
GenericEnvironment *genericEnv = nullptr);
581578

579+
/// Returns the indices of all semantic results for a given function.
580+
IndexSubset *getAllFunctionSemanticResultIndices(
581+
const AbstractFunctionDecl *AFD);
582+
582583
/// Returns the lowered SIL parameter indices for the given AST parameter
583584
/// indices and `AnyfunctionType`.
584585
///

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3950,9 +3950,6 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none,
39503950
(DescriptiveDeclKind))
39513951
ERROR(autodiff_attr_original_void_result,none,
39523952
"cannot differentiate void function %0", (DeclName))
3953-
ERROR(autodiff_attr_original_multiple_semantic_results,none,
3954-
"cannot differentiate functions with both an 'inout' parameter and a "
3955-
"result", ())
39563953
ERROR(autodiff_attr_result_not_differentiable,none,
39573954
"can only differentiate functions with results that conform to "
39583955
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))

lib/AST/AutoDiff.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,16 @@ void autodiff::getFunctionSemanticResultTypes(
196196
functionType->getResult()->getAs<AnyFunctionType>()) {
197197
formalResultType = resultFunctionType->getResult();
198198
}
199-
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
200-
result.push_back({remap(formalResultType), /*isInout*/ false});
199+
if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
200+
// Separate tuple elements into individual results.
201+
if (formalResultType->is<TupleType>()) {
202+
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
203+
result.push_back({remap(elt.getType()), /*isInout*/ false});
204+
}
205+
} else {
206+
result.push_back({remap(formalResultType), /*isInout*/ false});
207+
}
208+
}
201209

202210
// Collect `inout` parameters as semantic results.
203211
for (auto param : functionType->getParams())
@@ -211,6 +219,16 @@ void autodiff::getFunctionSemanticResultTypes(
211219
}
212220
}
213221

222+
IndexSubset *
223+
autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) {
224+
auto originalFn = AFD->getInterfaceType()->castTo<AnyFunctionType>();
225+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
226+
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
227+
auto numResults = semanticResults.size();
228+
return IndexSubset::getDefault(
229+
AFD->getASTContext(), numResults, /*includeAll*/ true);
230+
}
231+
214232
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
215233
IndexSubset *
216234
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,
@@ -395,9 +413,6 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
395413
case Kind::NoSemanticResults:
396414
OS << "has no semantic results ('Void' result)";
397415
break;
398-
case Kind::MultipleSemanticResults:
399-
OS << "has multiple semantic results";
400-
break;
401416
case Kind::NoDifferentiabilityParameters:
402417
OS << "has no differentiability parameters";
403418
break;

lib/AST/Type.cpp

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5548,31 +5548,86 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55485548
getSubsetParameters(parameterIndices, diffParams,
55495549
/*reverseCurryLevels*/ !makeSelfParamFirst);
55505550

5551-
// Get the original semantic result type.
5551+
// Get the original non-inout semantic result types.
55525552
SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
55535553
autodiff::getFunctionSemanticResultTypes(this, originalResults);
55545554
// Error if no original semantic results.
55555555
if (originalResults.empty())
55565556
return llvm::make_error<DerivativeFunctionTypeError>(
55575557
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
5558-
// Error if multiple original semantic results.
5559-
// TODO(TF-1250): Support functions with multiple semantic results.
5560-
if (originalResults.size() > 1)
5561-
return llvm::make_error<DerivativeFunctionTypeError>(
5562-
this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
5563-
auto originalResult = originalResults.front();
5564-
auto originalResultType = originalResult.type;
5565-
5566-
// Get the original semantic result type's `TangentVector` associated type.
5567-
auto resultTan =
5568-
originalResultType->getAutoDiffTangentSpace(lookupConformance);
5569-
// Error if original semantic result has no tangent space.
5570-
if (!resultTan) {
5558+
// Accumulate non-inout result tangent spaces.
5559+
SmallVector<Type, 1> resultTanTypes;
5560+
bool hasInoutResult = false;
5561+
for (auto i : range(originalResults.size())) {
5562+
auto originalResult = originalResults[i];
5563+
auto originalResultType = originalResult.type;
5564+
// Voids currently have a defined tangent vector, so ignore them.
5565+
if (originalResultType->isVoid())
5566+
continue;
5567+
if (originalResult.isInout) {
5568+
hasInoutResult = true;
5569+
continue;
5570+
}
5571+
// Get the original semantic result type's `TangentVector` associated type.
5572+
auto resultTan =
5573+
originalResultType->getAutoDiffTangentSpace(lookupConformance);
5574+
if (!resultTan)
5575+
continue;
5576+
auto resultTanType = resultTan->getType();
5577+
resultTanTypes.push_back(resultTanType);
5578+
}
5579+
// Append non-wrt inout result tangent spaces.
5580+
// This uses the logic from getSubsetParameters(), only operating over all
5581+
// parameter indices and looking for non-wrt indices.
5582+
SmallVector<AnyFunctionType *, 2> curryLevels;
5583+
// An inlined version of unwrapCurryLevels().
5584+
AnyFunctionType *fnTy = this;
5585+
while (fnTy != nullptr) {
5586+
curryLevels.push_back(fnTy);
5587+
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
5588+
}
5589+
5590+
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
5591+
unsigned currentOffset = 0;
5592+
for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
5593+
curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
5594+
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
5595+
}
5596+
5597+
if (!makeSelfParamFirst) {
5598+
std::reverse(curryLevels.begin(), curryLevels.end());
5599+
std::reverse(curryLevelParameterIndexOffsets.begin(),
5600+
curryLevelParameterIndexOffsets.end());
5601+
}
5602+
5603+
for (unsigned curryLevelIndex : indices(curryLevels)) {
5604+
auto *curryLevel = curryLevels[curryLevelIndex];
5605+
unsigned parameterIndexOffset =
5606+
curryLevelParameterIndexOffsets[curryLevelIndex];
5607+
for (unsigned paramIndex : range(curryLevel->getNumParams())) {
5608+
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
5609+
continue;
5610+
5611+
auto param = curryLevel->getParams()[paramIndex];
5612+
if (param.isInOut()) {
5613+
auto resultType = param.getPlainType();
5614+
if (resultType->isVoid())
5615+
continue;
5616+
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
5617+
if (!resultTan)
5618+
continue;
5619+
auto resultTanType = resultTan->getType();
5620+
resultTanTypes.push_back(resultTanType);
5621+
}
5622+
}
5623+
}
5624+
5625+
// Error if no semantic result has a tangent space.
5626+
if (resultTanTypes.empty() && !hasInoutResult) {
55715627
return llvm::make_error<DerivativeFunctionTypeError>(
55725628
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5573-
std::make_pair(originalResultType, /*index*/ 0));
5629+
std::make_pair(originalResults.front().type, /*index*/ 0));
55745630
}
5575-
auto resultTanType = resultTan->getType();
55765631

55775632
// Compute the result linear map function type.
55785633
FunctionType *linearMapType;
@@ -5592,7 +5647,6 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55925647
// - Original: `(T0, inout T1, ...) -> Void`
55935648
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
55945649
SmallVector<AnyFunctionType::Param, 4> differentialParams;
5595-
bool hasInoutDiffParameter = false;
55965650
for (auto i : range(diffParams.size())) {
55975651
auto diffParam = diffParams[i];
55985652
auto paramType = diffParam.getPlainType();
@@ -5607,11 +5661,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56075661
}
56085662
differentialParams.push_back(AnyFunctionType::Param(
56095663
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
5610-
if (diffParam.isInOut())
5611-
hasInoutDiffParameter = true;
56125664
}
5613-
auto differentialResult =
5614-
hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType;
5665+
Type differentialResult;
5666+
if (resultTanTypes.empty()) {
5667+
differentialResult = ctx.TheEmptyTupleType;
5668+
} else if (resultTanTypes.size() == 1) {
5669+
differentialResult = resultTanTypes.front();
5670+
} else {
5671+
SmallVector<TupleTypeElt, 2> differentialResults;
5672+
for (auto i : range(resultTanTypes.size())) {
5673+
auto resultTanType = resultTanTypes[i];
5674+
differentialResults.push_back(
5675+
TupleTypeElt(resultTanType, Identifier()));
5676+
}
5677+
differentialResult = TupleType::get(differentialResults, ctx);
5678+
}
5679+
56155680
// FIXME: Verify ExtInfo state is correct, not working by accident.
56165681
FunctionType::ExtInfo info;
56175682
linearMapType =
@@ -5629,11 +5694,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56295694
// - Original: `(T0, inout T1, ...) -> Void`
56305695
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
56315696
//
5632-
// Case 3: original function has a wrt `inout` parameter.
5633-
// - Original: `(T0, inout T1, ...) -> Void`
5634-
// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
5697+
// Case 3: original function has wrt `inout` parameters.
5698+
// - Original: `(T0, inout T1, ...) -> R`
5699+
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
56355700
SmallVector<TupleTypeElt, 4> pullbackResults;
5636-
bool hasInoutDiffParameter = false;
5701+
SmallVector<AnyFunctionType::Param, 2> inoutParams;
56375702
for (auto i : range(diffParams.size())) {
56385703
auto diffParam = diffParams[i];
56395704
auto paramType = diffParam.getPlainType();
@@ -5647,7 +5712,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56475712
std::make_pair(paramType, i));
56485713
}
56495714
if (diffParam.isInOut()) {
5650-
hasInoutDiffParameter = true;
5715+
if (paramType->isVoid())
5716+
continue;
5717+
inoutParams.push_back(diffParam);
56515718
continue;
56525719
}
56535720
pullbackResults.emplace_back(paramTan->getType());
@@ -5660,12 +5727,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56605727
} else {
56615728
pullbackResult = TupleType::get(pullbackResults, ctx);
56625729
}
5663-
auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter);
5664-
auto pullbackParam =
5665-
AnyFunctionType::Param(resultTanType, Identifier(), flags);
5730+
// First accumulate non-inout results as pullback parameters.
5731+
SmallVector<FunctionType::Param, 2> pullbackParams;
5732+
for (auto i : range(resultTanTypes.size())) {
5733+
auto resultTanType = resultTanTypes[i];
5734+
auto flags = ParameterTypeFlags().withInOut(false);
5735+
pullbackParams.push_back(AnyFunctionType::Param(
5736+
resultTanType, Identifier(), flags));
5737+
}
5738+
// Then append inout parameters.
5739+
for (auto i : range(inoutParams.size())) {
5740+
auto inoutParam = inoutParams[i];
5741+
auto inoutParamType = inoutParam.getPlainType();
5742+
auto inoutParamTan =
5743+
inoutParamType->getAutoDiffTangentSpace(lookupConformance);
5744+
auto flags = ParameterTypeFlags().withInOut(true);
5745+
pullbackParams.push_back(AnyFunctionType::Param(
5746+
inoutParamTan->getType(), Identifier(), flags));
5747+
}
56665748
// FIXME: Verify ExtInfo state is correct, not working by accident.
56675749
FunctionType::ExtInfo info;
5668-
linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info);
5750+
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
56695751
break;
56705752
}
56715753
}

lib/IRGen/IRGenMangler.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ class IRGenMangler : public Mangle::ASTMangler {
5757
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
5858
beginManglingWithAutoDiffOriginalFunction(func);
5959
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
60+
auto *resultIndices =
61+
autodiff::getAllFunctionSemanticResultIndices(func);
6062
AutoDiffConfig config(
6163
derivativeId->getParameterIndices(),
62-
IndexSubset::get(func->getASTContext(), 1, {0}),
64+
resultIndices,
6365
derivativeId->getDerivativeGenericSignature());
6466
appendAutoDiffFunctionParts("TJ", kind, config);
6567
appendOperator("Tj");
@@ -86,9 +88,11 @@ class IRGenMangler : public Mangle::ASTMangler {
8688
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
8789
beginManglingWithAutoDiffOriginalFunction(func);
8890
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
91+
auto *resultIndices =
92+
autodiff::getAllFunctionSemanticResultIndices(func);
8993
AutoDiffConfig config(
9094
derivativeId->getParameterIndices(),
91-
IndexSubset::get(func->getASTContext(), 1, {0}),
95+
resultIndices,
9296
derivativeId->getDerivativeGenericSignature());
9397
appendAutoDiffFunctionParts("TJ", kind, config);
9498
appendOperator("Tq");

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
10671067
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
10681068
derivativeFunctionIdentifier->getParameterIndices(),
10691069
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
1070-
auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0});
1070+
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
1071+
asAutoDiffOriginalFunction().getAbstractFunctionDecl());
10711072
AutoDiffConfig silConfig(
10721073
silParameterIndices, resultIndices,
10731074
derivativeFunctionIdentifier->getDerivativeGenericSignature());

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,6 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
238238
resultIndices.push_back(resultAndIndex.index());
239239
// Check `inout` parameters.
240240
for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters()))
241-
// FIXME(TF-1305): The `getResults().empty()` condition is a hack.
242-
//
243241
// Currently, an `inout` parameter can either be:
244242
// 1. Both a differentiability parameter and a differentiability result.
245243
// 2. `@noDerivative`: neither a differentiability parameter nor a
@@ -251,13 +249,8 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
251249
// cases, so supporting it is a non-goal.
252250
//
253251
// See TF-1305 for solution ideas. For now, `@noDerivative` `inout`
254-
// parameters are not treated as differentiability results, unless the
255-
// original function has no formal results, in which case all `inout`
256252
// parameters are treated as differentiability results.
257-
if (getResults().empty() ||
258-
inoutParamAndIndex.value().getDifferentiability() !=
259-
SILParameterDifferentiability::NotDifferentiable)
260-
resultIndices.push_back(getNumResults() + inoutParamAndIndex.index());
253+
resultIndices.push_back(getNumResults() + inoutParamAndIndex.index());
261254
auto numSemanticResults =
262255
getNumResults() + getNumIndirectMutatingParameters();
263256
return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices);
@@ -603,7 +596,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
603596
differentialResults.push_back({resultTanType, resultConv});
604597
continue;
605598
}
606-
// Handle original `inout` parameter.
599+
// Handle original `inout` parameters.
607600
auto inoutParamIndex = resultIndex - originalFnTy->getNumResults();
608601
auto inoutParamIt = std::next(
609602
originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
@@ -745,7 +738,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
745738
pullbackParams.push_back({resultTanType, paramConv});
746739
continue;
747740
}
748-
// Handle original `inout` parameter.
741+
// Handle `inout` parameters.
749742
auto inoutParamIndex = resultIndex - originalFnTy->getNumResults();
750743
auto inoutParamIt = std::next(
751744
originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);

0 commit comments

Comments
 (0)