Skip to content

Commit 2cddff0

Browse files
committed
[AutoDiff] Type-checking support for inout parameter differentiation.
Semantically, an `inout` parameter is both a parameter and a result. `@differentiable` and `@derivative` attributes now support original functions with one "semantic result": either a formal result or an `inout` parameter. Derivative typing rules for functions with `inout` parameters are now defined. The differential/pullback type of a function with `inout` differentiability parameters also has `inout` parameters. This is ideal for performance. Differential typing rules: - Case 1: original function has no `inout` parameters. - Original: `(T0, T1, ...) -> R` - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` - Case 2: original function has a non-wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Differential: `(T0.Tan, ...) -> T1.Tan` - Case 3: original function has a wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` Pullback typing rules: - Case 1: original function has no `inout` parameters. - Original: `(T0, T1, ...) -> R` - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` - Case 2: original function has a non-wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - Case 3: original function has a wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` Resolves TF-1164.
1 parent aef19c9 commit 2cddff0

File tree

10 files changed

+625
-230
lines changed

10 files changed

+625
-230
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ struct AutoDiffConfig {
122122
SWIFT_DEBUG_DUMP;
123123
};
124124

125+
/// A semantic function result type: either a formal function result type or
126+
/// an `inout` parameter type. Used in derivative function type calculation.
127+
struct AutoDiffSemanticFunctionResultType {
128+
Type type;
129+
bool isInout;
130+
};
131+
125132
/// Key for caching SIL derivative function types.
126133
struct SILAutoDiffDerivativeFunctionKey {
127134
SILFunctionType *originalType;
@@ -271,11 +278,17 @@ using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
271278
/// Automatic differentiation utility namespace.
272279
namespace autodiff {
273280

274-
/// Appends the subset's parameter's types to `results`, in the order in
275-
/// which they appear in the function type.
276-
void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
277-
SmallVectorImpl<Type> &results,
278-
bool reverseCurryLevels = false);
281+
/// Given a function type, collects its semantic result types in type order
282+
/// into `result`: first, the formal result type (if non-`Void`), followed by
283+
/// `inout` parameter types.
284+
///
285+
/// The function type may have at most two parameter lists.
286+
///
287+
/// Remaps the original semantic result using `genericEnv`, if specified.
288+
void getFunctionSemanticResultTypes(
289+
AnyFunctionType *functionType,
290+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
291+
GenericEnvironment *genericEnv = nullptr);
279292

280293
/// "Constrained" derivative generic signatures require all differentiability
281294
/// parameters to conform to the `Differentiable` protocol.

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2978,8 +2978,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
29782978
(DeclName, DeclName))
29792979

29802980
// @differentiable
2981-
ERROR(differentiable_attr_void_result,none,
2982-
"cannot differentiate void function %0", (DeclName))
29832981
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
29842982
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
29852983
"attribute for transpose registration instead", ())
@@ -3097,6 +3095,11 @@ ERROR(autodiff_attr_original_decl_none_valid_found,none,
30973095
"could not find function %0 with expected type %1", (DeclNameRef, Type))
30983096
ERROR(autodiff_attr_original_decl_not_same_type_context,none,
30993097
"%0 is not defined in the current type context", (DeclNameRef))
3098+
ERROR(autodiff_attr_original_void_result,none,
3099+
"cannot differentiate void function %0", (DeclName))
3100+
ERROR(autodiff_attr_original_multiple_semantic_results,none,
3101+
"cannot differentiate functions with both an 'inout' parameter and a "
3102+
"result", ())
31003103

31013104
// differentiation `wrt` parameters clause
31023105
ERROR(diff_function_no_parameters,none,

include/swift/AST/Types.h

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3214,15 +3214,25 @@ class AnyFunctionType : public TypeBase {
32143214
return getExtInfo().getRepresentation();
32153215
}
32163216

3217+
/// Appends the parameters indicated by `parameterIndices` to `results`.
3218+
///
3219+
/// For curried function types: if `reverseCurryLevels` is true, append
3220+
/// the `self` parameter last instead of first.
3221+
///
3222+
/// TODO(TF-874): Simplify logic and remove the `reverseCurryLevels` flag.
3223+
void getSubsetParameters(IndexSubset *parameterIndices,
3224+
SmallVectorImpl<AnyFunctionType::Param> &results,
3225+
bool reverseCurryLevels = false);
3226+
32173227
/// Returns the derivative function type for the given parameter indices,
32183228
/// result index, derivative function kind, derivative function generic
32193229
/// signature (optional), and other auxiliary parameters.
32203230
///
32213231
/// Preconditions:
32223232
/// - Parameters corresponding to parameter indices must conform to
32233233
/// `Differentiable`.
3224-
/// - The result corresponding to the result index must conform to
3225-
/// `Differentiable`.
3234+
/// - There is one semantic function result type: either the formal original
3235+
/// result or an `inout` parameter. It must conform to `Differentiable`.
32263236
///
32273237
/// Typing rules, given:
32283238
/// - Original function type. Three cases:
@@ -3268,6 +3278,11 @@ class AnyFunctionType : public TypeBase {
32683278
/// original result | deriv. wrt result | deriv. wrt params
32693279
/// \endverbatim
32703280
///
3281+
/// The original type may have `inout` parameters. If so, the
3282+
/// differential/pullback typing rules are more nuanced: see documentation for
3283+
/// `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
3284+
/// `inout` parameters behave as both parameters and results.
3285+
///
32713286
/// By default, if the original type has a `self` parameter list and parameter
32723287
/// indices include `self`, the computed derivative function type will return
32733288
/// a linear map taking/returning self's tangent *last* instead of first, for
@@ -3278,14 +3293,57 @@ class AnyFunctionType : public TypeBase {
32783293
/// derivative function types, e.g. when type-checking `@differentiable` and
32793294
/// `@derivative` attributes.
32803295
AnyFunctionType *getAutoDiffDerivativeFunctionType(
3281-
IndexSubset *parameterIndices, unsigned resultIndex,
3282-
AutoDiffDerivativeFunctionKind kind,
3296+
IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
32833297
LookupConformanceFn lookupConformance,
32843298
GenericSignature derivativeGenericSignature = GenericSignature(),
32853299
bool makeSelfParamFirst = false);
32863300

3301+
/// Returns the corresponding linear map function type for the given parameter
3302+
/// indices, linear map function kind, and other auxiliary parameters.
3303+
///
3304+
/// Preconditions:
3305+
/// - Parameters corresponding to parameter indices must conform to
3306+
/// `Differentiable`.
3307+
/// - There is one semantic function result type: either the formal original
3308+
/// result or an `inout` parameter. It must conform to `Differentiable`.
3309+
///
3310+
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
3311+
/// "wrt" result derivative.
3312+
///
3313+
/// - Case 1: original function has no `inout` parameters.
3314+
/// - Original: `(T0, T1, ...) -> R`
3315+
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
3316+
/// - Case 2: original function has a non-wrt `inout` parameter.
3317+
/// - Original: `(T0, inout T1, ...) -> Void`
3318+
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
3319+
/// - Case 3: original function has a wrt `inout` parameter.
3320+
/// - Original: `(T0, inout T1, ...) -> Void`
3321+
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
3322+
///
3323+
/// Pullback typing rules: takes a "wrt" result derivative and returns "wrt"
3324+
/// parameter derivatives.
3325+
///
3326+
/// - Case 1: original function has no `inout` parameters.
3327+
/// - Original: `(T0, T1, ...) -> R`
3328+
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
3329+
/// - Case 2: original function has a non-wrt `inout` parameter.
3330+
/// - Original: `(T0, inout T1, ...) -> Void`
3331+
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
3332+
/// - Case 3: original function has a wrt `inout` parameter.
3333+
/// - Original: `(T0, inout T1, ...) -> Void`
3334+
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
3335+
///
3336+
/// If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
3337+
/// first. `makeSelfParamFirst` should be true when working with user-facing
3338+
/// derivative function types, e.g. when type-checking `@differentiable` and
3339+
/// `@derivative` attributes.
3340+
AnyFunctionType *getAutoDiffDerivativeFunctionLinearMapType(
3341+
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
3342+
LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false);
3343+
32873344
// SWIFT_ENABLE_TENSORFLOW
32883345
AnyFunctionType *getWithoutDifferentiability() const;
3346+
// SWIFT_ENABLE_TENSORFLOW END
32893347

32903348
/// True if the parameter declaration it is attached to is guaranteed
32913349
/// to not persist the closure for longer than the duration of the call.
@@ -4420,6 +4478,28 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
44204478
return getParameters().back();
44214479
}
44224480

4481+
struct IndirectMutatingParameterFilter {
4482+
bool operator()(SILParameterInfo param) const {
4483+
return param.isIndirectMutating();
4484+
}
4485+
};
4486+
using IndirectMutatingParameterIter =
4487+
llvm::filter_iterator<const SILParameterInfo *,
4488+
IndirectMutatingParameterFilter>;
4489+
using IndirectMutatingParameterRange =
4490+
iterator_range<IndirectMutatingParameterIter>;
4491+
4492+
/// A range of SILParameterInfo for all indirect mutating parameters.
4493+
IndirectMutatingParameterRange getIndirectMutatingParameters() const {
4494+
return llvm::make_filter_range(getParameters(),
4495+
IndirectMutatingParameterFilter());
4496+
}
4497+
4498+
/// Returns the number of indirect mutating parameters.
4499+
unsigned getNumIndirectMutatingParameters() const {
4500+
return llvm::count_if(getParameters(), IndirectMutatingParameterFilter());
4501+
}
4502+
44234503
/// Get the generic signature used to apply the substitutions of a substituted function type
44244504
CanGenericSignature getSubstGenericSignature() const {
44254505
return GenericSigAndIsImplied.getPointer();
@@ -4488,18 +4568,27 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
44884568
/// - Returns original results, followed by a differential function, which
44894569
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
44904570
///
4571+
/// \verbatim
44914572
/// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
44924573
/// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
44934574
/// original results | derivatives wrt params | derivative wrt result
4575+
/// \endverbatim
44944576
///
44954577
/// VJP derivative type:
44964578
/// - Takes original parameters.
44974579
/// - Returns original results, followed by a pullback function, which
44984580
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
44994581
///
4582+
/// \verbatim
45004583
/// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
45014584
/// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
45024585
/// original results | derivative wrt result | derivatives wrt params
4586+
/// \endverbatim
4587+
///
4588+
/// The original type may have `inout` parameters. If so, the
4589+
/// differential/pullback typing rules are more nuanced: see documentation for
4590+
/// `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
4591+
/// `inout` parameters behave as both parameters and results.
45034592
///
45044593
/// A "constrained derivative generic signature" is computed from
45054594
/// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is

lib/AST/AutoDiff.cpp

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#include "swift/AST/ASTContext.h"
1413
#include "swift/AST/AutoDiff.h"
1514
#include "swift/AST/ASTContext.h"
15+
#include "swift/AST/GenericEnvironment.h"
1616
#include "swift/AST/Module.h"
1717
#include "swift/AST/TypeCheckRequests.h"
1818
#include "swift/AST/Types.h"
@@ -73,13 +73,11 @@ static unsigned countNumFlattenedElementTypes(Type type) {
7373
}
7474

7575
// TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag.
76-
// See TF-874 for WIP.
77-
void autodiff::getSubsetParameterTypes(IndexSubset *subset,
78-
AnyFunctionType *type,
79-
SmallVectorImpl<Type> &results,
80-
bool reverseCurryLevels) {
76+
void AnyFunctionType::getSubsetParameters(
77+
IndexSubset *parameterIndices,
78+
SmallVectorImpl<AnyFunctionType::Param> &results, bool reverseCurryLevels) {
8179
SmallVector<AnyFunctionType *, 2> curryLevels;
82-
unwrapCurryLevels(type, curryLevels);
80+
unwrapCurryLevels(this, curryLevels);
8381

8482
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
8583
unsigned currentOffset = 0;
@@ -100,8 +98,43 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
10098
unsigned parameterIndexOffset =
10199
curryLevelParameterIndexOffsets[curryLevelIndex];
102100
for (unsigned paramIndex : range(curryLevel->getNumParams()))
103-
if (subset->contains(parameterIndexOffset + paramIndex))
104-
results.push_back(curryLevel->getParams()[paramIndex].getOldType());
101+
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
102+
results.push_back(curryLevel->getParams()[paramIndex]);
103+
}
104+
}
105+
106+
void autodiff::getFunctionSemanticResultTypes(
107+
AnyFunctionType *functionType,
108+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
109+
GenericEnvironment *genericEnv) {
110+
auto &ctx = functionType->getASTContext();
111+
112+
// Remap type in `genericEnv`, if specified.
113+
auto remap = [&](Type type) {
114+
if (!genericEnv)
115+
return type;
116+
return genericEnv->mapTypeIntoContext(type);
117+
};
118+
119+
// Collect formal result type as a semantic result, unless it is
120+
// `Void`.
121+
auto formalResultType = functionType->getResult();
122+
if (auto *resultFunctionType =
123+
functionType->getResult()->getAs<AnyFunctionType>()) {
124+
formalResultType = resultFunctionType->getResult();
125+
}
126+
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
127+
result.push_back({remap(formalResultType), /*isInout*/ false});
128+
129+
// Collect `inout` parameters as semantic results.
130+
for (auto param : functionType->getParams())
131+
if (param.isInOut())
132+
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
133+
if (auto *resultFunctionType =
134+
functionType->getResult()->getAs<AnyFunctionType>()) {
135+
for (auto param : resultFunctionType->getParams())
136+
if (param.isInOut())
137+
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
105138
}
106139
}
107140

0 commit comments

Comments
 (0)