Skip to content

Commit eb82df6

Browse files
aslBradLarson
andauthored
[AutoDiff] Support differentiable functions with multiple semantic results (#66873)
Add support for differentiable functions having multiple semantic results Co-authored-by: Brad Larson <[email protected]>
1 parent 29ce7a3 commit eb82df6

26 files changed

+616
-173
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
namespace swift {
3434

35+
class AbstractFunctionDecl;
3536
class AnyFunctionType;
3637
class SourceFile;
3738
class SILFunctionType;
@@ -249,7 +250,12 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
249250
/// an `inout` parameter type. Used in derivative function type calculation.
250251
struct AutoDiffSemanticFunctionResultType {
251252
Type type;
252-
bool isInout;
253+
unsigned index : 30;
254+
bool isInout : 1;
255+
bool isWrtParam : 1;
256+
257+
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt)
258+
: type(t), index(idx), isInout(inout), isWrtParam(wrt) { }
253259
};
254260

255261
/// Key for caching SIL derivative function types.
@@ -400,9 +406,6 @@ class DerivativeFunctionTypeError
400406
enum class Kind {
401407
/// Original function type has no semantic results.
402408
NoSemanticResults,
403-
/// Original function type has multiple semantic results.
404-
// TODO(TF-1250): Support function types with multiple semantic results.
405-
MultipleSemanticResults,
406409
/// Differentiability parmeter indices are empty.
407410
NoDifferentiabilityParameters,
408411
/// A differentiability parameter does not conform to `Differentiable`.
@@ -431,7 +434,6 @@ class DerivativeFunctionTypeError
431434
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
432435
: functionType(functionType), kind(kind), value(Value()) {
433436
assert(kind == Kind::NoSemanticResults ||
434-
kind == Kind::MultipleSemanticResults ||
435437
kind == Kind::NoDifferentiabilityParameters);
436438
};
437439

@@ -574,15 +576,22 @@ namespace autodiff {
574576
/// `inout` parameter types.
575577
///
576578
/// The function type may have at most two parameter lists.
577-
///
578-
/// Remaps the original semantic result using `genericEnv`, if specified.
579-
void getFunctionSemanticResultTypes(
580-
AnyFunctionType *functionType,
581-
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
582-
GenericEnvironment *genericEnv = nullptr);
579+
void getFunctionSemanticResults(
580+
const AnyFunctionType *functionType,
581+
const IndexSubset *parameterIndices,
582+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes);
583+
584+
/// Returns the indices of semantic results for a given function.
585+
IndexSubset *getFunctionSemanticResultIndices(
586+
const AnyFunctionType *functionType,
587+
const IndexSubset *parameterIndices);
588+
589+
IndexSubset *getFunctionSemanticResultIndices(
590+
const AbstractFunctionDecl *AFD,
591+
const IndexSubset *parameterIndices);
583592

584593
/// Returns the lowered SIL parameter indices for the given AST parameter
585-
/// indices and `AnyfunctionType`.
594+
/// indices and `AnyFunctionType`.
586595
///
587596
/// Notable lowering-related changes:
588597
/// - AST tuple parameter types are exploded when lowered to SIL.

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4004,9 +4004,6 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none,
40044004
(DescriptiveDeclKind))
40054005
ERROR(autodiff_attr_original_void_result,none,
40064006
"cannot differentiate void function %0", (DeclName))
4007-
ERROR(autodiff_attr_original_multiple_semantic_results,none,
4008-
"cannot differentiate functions with both an 'inout' parameter and a "
4009-
"result", ())
40104007
ERROR(autodiff_attr_result_not_differentiable,none,
40114008
"can only differentiate functions with results that conform to "
40124009
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))

lib/AST/AutoDiff.cpp

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -180,39 +180,89 @@ void AnyFunctionType::getSubsetParameters(
180180
}
181181
}
182182

183-
void autodiff::getFunctionSemanticResultTypes(
184-
AnyFunctionType *functionType,
185-
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
186-
GenericEnvironment *genericEnv) {
183+
void autodiff::getFunctionSemanticResults(
184+
const AnyFunctionType *functionType,
185+
const IndexSubset *parameterIndices,
186+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
187187
auto &ctx = functionType->getASTContext();
188188

189-
// Remap type in `genericEnv`, if specified.
190-
auto remap = [&](Type type) {
191-
if (!genericEnv)
192-
return type;
193-
return genericEnv->mapTypeIntoContext(type);
194-
};
195-
196189
// Collect formal result type as a semantic result, unless it is
197190
// `Void`.
198191
auto formalResultType = functionType->getResult();
199192
if (auto *resultFunctionType =
200-
functionType->getResult()->getAs<AnyFunctionType>()) {
193+
functionType->getResult()->getAs<AnyFunctionType>())
201194
formalResultType = resultFunctionType->getResult();
195+
196+
unsigned resultIdx = 0;
197+
if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
198+
// Separate tuple elements into individual results.
199+
if (formalResultType->is<TupleType>()) {
200+
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
201+
resultTypes.emplace_back(elt.getType(), resultIdx++,
202+
/*isInout*/ false, /*isWrt*/ false);
203+
}
204+
} else {
205+
resultTypes.emplace_back(formalResultType, resultIdx++,
206+
/*isInout*/ false, /*isWrt*/ false);
207+
}
202208
}
203-
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
204-
result.push_back({remap(formalResultType), /*isInout*/ false});
205209

206-
// Collect `inout` parameters as semantic results.
207-
for (auto param : functionType->getParams())
208-
if (param.isInOut())
209-
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
210-
if (auto *resultFunctionType =
211-
functionType->getResult()->getAs<AnyFunctionType>()) {
212-
for (auto param : resultFunctionType->getParams())
213-
if (param.isInOut())
214-
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
210+
bool addNonWrts = resultTypes.empty();
211+
212+
// Collect wrt `inout` parameters as semantic results
213+
// As an extention, collect all (including non-wrt) inouts as results for
214+
// functions returning void.
215+
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
216+
unsigned curryOffset = 0) {
217+
for (auto paramAndIndex : enumerate(functionType->getParams())) {
218+
if (!paramAndIndex.value().isInOut())
219+
continue;
220+
221+
unsigned idx = paramAndIndex.index() + curryOffset;
222+
assert(idx < parameterIndices->getCapacity() &&
223+
"invalid parameter index");
224+
bool isWrt = parameterIndices->contains(idx);
225+
if (addNonWrts || isWrt)
226+
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
227+
resultIdx, /*isInout*/ true, isWrt);
228+
resultIdx += 1;
229+
}
230+
};
231+
232+
if (auto *resultFnType =
233+
functionType->getResult()->getAs<AnyFunctionType>()) {
234+
// Here we assume that the input is a function type with curried `Self`
235+
assert(functionType->getNumParams() == 1 && "unexpected function type");
236+
237+
collectSemanticResults(resultFnType);
238+
collectSemanticResults(functionType, resultFnType->getNumParams());
239+
} else
240+
collectSemanticResults(functionType);
241+
}
242+
243+
IndexSubset *
244+
autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType,
245+
const IndexSubset *parameterIndices) {
246+
auto &ctx = functionType->getASTContext();
247+
248+
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
249+
autodiff::getFunctionSemanticResults(functionType, parameterIndices,
250+
semanticResults);
251+
SmallVector<unsigned> resultIndices;
252+
unsigned cap = 0;
253+
for (const auto& result : semanticResults) {
254+
resultIndices.push_back(result.index);
255+
cap = std::max(cap, result.index + 1U);
215256
}
257+
258+
return IndexSubset::get(ctx, cap, resultIndices);
259+
}
260+
261+
IndexSubset *
262+
autodiff::getFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD,
263+
const IndexSubset *parameterIndices) {
264+
return getFunctionSemanticResultIndices(AFD->getInterfaceType()->castTo<AnyFunctionType>(),
265+
parameterIndices);
216266
}
217267

218268
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
@@ -399,9 +449,6 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
399449
case Kind::NoSemanticResults:
400450
OS << "has no semantic results ('Void' result)";
401451
break;
402-
case Kind::MultipleSemanticResults:
403-
OS << "has multiple semantic results";
404-
break;
405452
case Kind::NoDifferentiabilityParameters:
406453
OS << "has no differentiability parameters";
407454
break;

lib/AST/Type.cpp

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5549,32 +5549,43 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55495549
getSubsetParameters(parameterIndices, diffParams,
55505550
/*reverseCurryLevels*/ !makeSelfParamFirst);
55515551

5552-
// Get the original semantic result type.
5552+
// Get the original non-inout semantic result types.
55535553
SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
5554-
autodiff::getFunctionSemanticResultTypes(this, originalResults);
5554+
autodiff::getFunctionSemanticResults(this, parameterIndices, originalResults);
55555555
// Error if no original semantic results.
55565556
if (originalResults.empty())
55575557
return llvm::make_error<DerivativeFunctionTypeError>(
55585558
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
5559-
// Error if multiple original semantic results.
5560-
// TODO(TF-1250): Support functions with multiple semantic results.
5561-
if (originalResults.size() > 1)
5562-
return llvm::make_error<DerivativeFunctionTypeError>(
5563-
this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
5564-
auto originalResult = originalResults.front();
5565-
auto originalResultType = originalResult.type;
5566-
5567-
// Get the original semantic result type's `TangentVector` associated type.
5568-
auto resultTan =
5569-
originalResultType->getAutoDiffTangentSpace(lookupConformance);
5570-
// Error if original semantic result has no tangent space.
5571-
if (!resultTan) {
5572-
return llvm::make_error<DerivativeFunctionTypeError>(
5559+
5560+
// Accumulate non-inout result tangent spaces.
5561+
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
5562+
for (auto i : range(originalResults.size())) {
5563+
auto originalResult = originalResults[i];
5564+
auto originalResultType = originalResult.type;
5565+
5566+
// Voids currently have a defined tangent vector, so ignore them.
5567+
if (originalResultType->isVoid())
5568+
continue;
5569+
5570+
// Get the original semantic result type's `TangentVector` associated type.
5571+
// Error if a semantic result has no tangent space.
5572+
auto resultTan =
5573+
originalResultType->getAutoDiffTangentSpace(lookupConformance);
5574+
if (!resultTan)
5575+
return llvm::make_error<DerivativeFunctionTypeError>(
55735576
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5574-
std::make_pair(originalResultType, /*index*/ 0));
5577+
std::make_pair(originalResultType, unsigned(originalResult.index)));
5578+
5579+
if (!originalResult.isInout)
5580+
resultTanTypes.push_back(resultTan->getType());
5581+
else if (originalResult.isInout && !originalResult.isWrtParam)
5582+
inoutTanTypes.push_back(resultTan->getType());
55755583
}
5576-
auto resultTanType = resultTan->getType();
55775584

5585+
// Treat non-wrt inouts as semantic results for functions returning Void
5586+
if (resultTanTypes.empty())
5587+
resultTanTypes = inoutTanTypes;
5588+
55785589
// Compute the result linear map function type.
55795590
FunctionType *linearMapType;
55805591
switch (kind) {
@@ -5587,32 +5598,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55875598
//
55885599
// Case 2: original function has a non-wrt `inout` parameter.
55895600
// - Original: `(T0, inout T1, ...) -> Void`
5590-
// - Differential: `(T0.Tan, ...) -> T1.Tan`
5601+
// - Differential: `(T0.Tan, ...) -> T1.Tan`
55915602
//
55925603
// Case 3: original function has a wrt `inout` parameter.
5593-
// - Original: `(T0, inout T1, ...) -> Void`
5594-
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
5604+
// - Original: `(T0, inout T1, ...) -> Void`
5605+
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
55955606
SmallVector<AnyFunctionType::Param, 4> differentialParams;
5596-
bool hasInoutDiffParameter = false;
55975607
for (auto i : range(diffParams.size())) {
55985608
auto diffParam = diffParams[i];
55995609
auto paramType = diffParam.getPlainType();
56005610
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
56015611
// Error if parameter has no tangent space.
5602-
if (!paramTan) {
5612+
if (!paramTan)
56035613
return llvm::make_error<DerivativeFunctionTypeError>(
56045614
this,
56055615
DerivativeFunctionTypeError::Kind::
56065616
NonDifferentiableDifferentiabilityParameter,
56075617
std::make_pair(paramType, i));
5608-
}
5618+
56095619
differentialParams.push_back(AnyFunctionType::Param(
56105620
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
5611-
if (diffParam.isInOut())
5612-
hasInoutDiffParameter = true;
56135621
}
5614-
auto differentialResult =
5615-
hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType;
5622+
Type differentialResult;
5623+
if (resultTanTypes.empty()) {
5624+
differentialResult = ctx.TheEmptyTupleType;
5625+
} else if (resultTanTypes.size() == 1) {
5626+
differentialResult = resultTanTypes.front();
5627+
} else {
5628+
SmallVector<TupleTypeElt, 2> differentialResults;
5629+
for (auto i : range(resultTanTypes.size())) {
5630+
auto resultTanType = resultTanTypes[i];
5631+
differentialResults.push_back(
5632+
TupleTypeElt(resultTanType, Identifier()));
5633+
}
5634+
differentialResult = TupleType::get(differentialResults, ctx);
5635+
}
5636+
56165637
// FIXME: Verify ExtInfo state is correct, not working by accident.
56175638
FunctionType::ExtInfo info;
56185639
linearMapType =
@@ -5630,25 +5651,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56305651
// - Original: `(T0, inout T1, ...) -> Void`
56315652
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
56325653
//
5633-
// Case 3: original function has a wrt `inout` parameter.
5634-
// - Original: `(T0, inout T1, ...) -> Void`
5635-
// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
5654+
// Case 3: original function has wrt `inout` parameters.
5655+
// - Original: `(T0, inout T1, ...) -> R`
5656+
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
56365657
SmallVector<TupleTypeElt, 4> pullbackResults;
5637-
bool hasInoutDiffParameter = false;
5658+
SmallVector<AnyFunctionType::Param, 2> inoutParams;
56385659
for (auto i : range(diffParams.size())) {
56395660
auto diffParam = diffParams[i];
56405661
auto paramType = diffParam.getPlainType();
56415662
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
56425663
// Error if parameter has no tangent space.
5643-
if (!paramTan) {
5664+
if (!paramTan)
56445665
return llvm::make_error<DerivativeFunctionTypeError>(
56455666
this,
56465667
DerivativeFunctionTypeError::Kind::
56475668
NonDifferentiableDifferentiabilityParameter,
56485669
std::make_pair(paramType, i));
5649-
}
5670+
56505671
if (diffParam.isInOut()) {
5651-
hasInoutDiffParameter = true;
5672+
if (paramType->isVoid())
5673+
continue;
5674+
inoutParams.push_back(diffParam);
56525675
continue;
56535676
}
56545677
pullbackResults.emplace_back(paramTan->getType());
@@ -5661,12 +5684,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56615684
} else {
56625685
pullbackResult = TupleType::get(pullbackResults, ctx);
56635686
}
5664-
auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter);
5665-
auto pullbackParam =
5666-
AnyFunctionType::Param(resultTanType, Identifier(), flags);
5687+
// First accumulate non-inout results as pullback parameters.
5688+
SmallVector<FunctionType::Param, 2> pullbackParams;
5689+
for (auto i : range(resultTanTypes.size())) {
5690+
auto resultTanType = resultTanTypes[i];
5691+
auto flags = ParameterTypeFlags().withInOut(false);
5692+
pullbackParams.push_back(AnyFunctionType::Param(
5693+
resultTanType, Identifier(), flags));
5694+
}
5695+
// Then append inout parameters.
5696+
for (auto i : range(inoutParams.size())) {
5697+
auto inoutParam = inoutParams[i];
5698+
auto inoutParamType = inoutParam.getPlainType();
5699+
auto inoutParamTan =
5700+
inoutParamType->getAutoDiffTangentSpace(lookupConformance);
5701+
auto flags = ParameterTypeFlags().withInOut(true);
5702+
pullbackParams.push_back(AnyFunctionType::Param(
5703+
inoutParamTan->getType(), Identifier(), flags));
5704+
}
56675705
// FIXME: Verify ExtInfo state is correct, not working by accident.
56685706
FunctionType::ExtInfo info;
5669-
linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info);
5707+
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
56705708
break;
56715709
}
56725710
}

0 commit comments

Comments
 (0)