@@ -3214,15 +3214,25 @@ class AnyFunctionType : public TypeBase {
3214
3214
return getExtInfo ().getRepresentation ();
3215
3215
}
3216
3216
3217
+ // / Appends the parameters indicated by `parameterIndices` to `results`.
3218
+ // /
3219
+ // / For curried function types: if `reverseCurryLevels` is true, append
3220
+ // / the `self` parameter last instead of first.
3221
+ // /
3222
+ // / TODO(TF-874): Simplify logic and remove the `reverseCurryLevels` flag.
3223
+ void getSubsetParameters (IndexSubset *parameterIndices,
3224
+ SmallVectorImpl<AnyFunctionType::Param> &results,
3225
+ bool reverseCurryLevels = false );
3226
+
3217
3227
// / Returns the derivative function type for the given parameter indices,
3218
3228
// / result index, derivative function kind, derivative function generic
3219
3229
// / signature (optional), and other auxiliary parameters.
3220
3230
// /
3221
3231
// / Preconditions:
3222
3232
// / - Parameters corresponding to parameter indices must conform to
3223
3233
// / `Differentiable`.
3224
- // / - The result corresponding to the result index must conform to
3225
- // / `Differentiable`.
3234
+ // / - There is one semantic function result type: either the formal original
3235
+ // / result or an `inout` parameter. It must conform to `Differentiable`.
3226
3236
// /
3227
3237
// / Typing rules, given:
3228
3238
// / - Original function type. Three cases:
@@ -3268,6 +3278,11 @@ class AnyFunctionType : public TypeBase {
3268
3278
// / original result | deriv. wrt result | deriv. wrt params
3269
3279
// / \endverbatim
3270
3280
// /
3281
+ // / The original type may have `inout` parameters. If so, the
3282
+ // / differential/pullback typing rules are more nuanced: see documentation for
3283
+ // / `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
3284
+ // / `inout` parameters behave as both parameters and results.
3285
+ // /
3271
3286
// / By default, if the original type has a `self` parameter list and parameter
3272
3287
// / indices include `self`, the computed derivative function type will return
3273
3288
// / a linear map taking/returning self's tangent *last* instead of first, for
@@ -3278,14 +3293,57 @@ class AnyFunctionType : public TypeBase {
3278
3293
// / derivative function types, e.g. when type-checking `@differentiable` and
3279
3294
// / `@derivative` attributes.
3280
3295
AnyFunctionType *getAutoDiffDerivativeFunctionType (
3281
- IndexSubset *parameterIndices, unsigned resultIndex,
3282
- AutoDiffDerivativeFunctionKind kind,
3296
+ IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
3283
3297
LookupConformanceFn lookupConformance,
3284
3298
GenericSignature derivativeGenericSignature = GenericSignature(),
3285
3299
bool makeSelfParamFirst = false);
3286
3300
3301
+ // / Returns the corresponding linear map function type for the given parameter
3302
+ // / indices, linear map function kind, and other auxiliary parameters.
3303
+ // /
3304
+ // / Preconditions:
3305
+ // / - Parameters corresponding to parameter indices must conform to
3306
+ // / `Differentiable`.
3307
+ // / - There is one semantic function result type: either the formal original
3308
+ // / result or an `inout` parameter. It must conform to `Differentiable`.
3309
+ // /
3310
+ // / Differential typing rules: takes "wrt" parameter derivatives and returns a
3311
+ // / "wrt" result derivative.
3312
+ // /
3313
+ // / - Case 1: original function has no `inout` parameters.
3314
+ // / - Original: `(T0, T1, ...) -> R`
3315
+ // / - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
3316
+ // / - Case 2: original function has a non-wrt `inout` parameter.
3317
+ // / - Original: `(T0, inout T1, ...) -> Void`
3318
+ // / - Differential: `(T0.Tan, ...) -> T1.Tan`
3319
+ // / - Case 3: original function has a wrt `inout` parameter.
3320
+ // / - Original: `(T0, inout T1, ...) -> Void`
3321
+ // / - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
3322
+ // /
3323
+ // / Pullback typing rules: takes a "wrt" result derivative and returns "wrt"
3324
+ // / parameter derivatives.
3325
+ // /
3326
+ // / - Case 1: original function has no `inout` parameters.
3327
+ // / - Original: `(T0, T1, ...) -> R`
3328
+ // / - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
3329
+ // / - Case 2: original function has a non-wrt `inout` parameter.
3330
+ // / - Original: `(T0, inout T1, ...) -> Void`
3331
+ // / - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
3332
+ // / - Case 3: original function has a wrt `inout` parameter.
3333
+ // / - Original: `(T0, inout T1, ...) -> Void`
3334
+ // / - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
3335
+ // /
3336
+ // / If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
3337
+ // / first. `makeSelfParamFirst` should be true when working with user-facing
3338
+ // / derivative function types, e.g. when type-checking `@differentiable` and
3339
+ // / `@derivative` attributes.
3340
+ AnyFunctionType *getAutoDiffDerivativeFunctionLinearMapType (
3341
+ IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
3342
+ LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false );
3343
+
3287
3344
// SWIFT_ENABLE_TENSORFLOW
3288
3345
AnyFunctionType *getWithoutDifferentiability () const ;
3346
+ // SWIFT_ENABLE_TENSORFLOW END
3289
3347
3290
3348
// / True if the parameter declaration it is attached to is guaranteed
3291
3349
// / to not persist the closure for longer than the duration of the call.
@@ -4420,6 +4478,28 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
4420
4478
return getParameters ().back ();
4421
4479
}
4422
4480
4481
+ struct IndirectMutatingParameterFilter {
4482
+ bool operator ()(SILParameterInfo param) const {
4483
+ return param.isIndirectMutating ();
4484
+ }
4485
+ };
4486
+ using IndirectMutatingParameterIter =
4487
+ llvm::filter_iterator<const SILParameterInfo *,
4488
+ IndirectMutatingParameterFilter>;
4489
+ using IndirectMutatingParameterRange =
4490
+ iterator_range<IndirectMutatingParameterIter>;
4491
+
4492
+ // / A range of SILParameterInfo for all indirect mutating parameters.
4493
+ IndirectMutatingParameterRange getIndirectMutatingParameters () const {
4494
+ return llvm::make_filter_range (getParameters (),
4495
+ IndirectMutatingParameterFilter ());
4496
+ }
4497
+
4498
+ // / Returns the number of indirect mutating parameters.
4499
+ unsigned getNumIndirectMutatingParameters () const {
4500
+ return llvm::count_if (getParameters (), IndirectMutatingParameterFilter ());
4501
+ }
4502
+
4423
4503
// / Get the generic signature used to apply the substitutions of a substituted function type
4424
4504
CanGenericSignature getSubstGenericSignature () const {
4425
4505
return GenericSigAndIsImplied.getPointer ();
@@ -4488,18 +4568,27 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
4488
4568
// / - Returns original results, followed by a differential function, which
4489
4569
// / takes "wrt" parameter derivatives and returns a "wrt" result derivative.
4490
4570
// /
4571
+ // / \verbatim
4491
4572
// / $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
4492
4573
// / ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
4493
4574
// / original results | derivatives wrt params | derivative wrt result
4575
+ // / \endverbatim
4494
4576
// /
4495
4577
// / VJP derivative type:
4496
4578
// / - Takes original parameters.
4497
4579
// / - Returns original results, followed by a pullback function, which
4498
4580
// / takes a "wrt" result derivative and returns "wrt" parameter derivatives.
4499
4581
// /
4582
+ // / \verbatim
4500
4583
// / $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
4501
4584
// / ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
4502
4585
// / original results | derivative wrt result | derivatives wrt params
4586
+ // / \endverbatim
4587
+ // /
4588
+ // / The original type may have `inout` parameters. If so, the
4589
+ // / differential/pullback typing rules are more nuanced: see documentation for
4590
+ // / `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
4591
+ // / `inout` parameters behave as both parameters and results.
4503
4592
// /
4504
4593
// / A "constrained derivative generic signature" is computed from
4505
4594
// / `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
0 commit comments