-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Add 'Builtin.applyTranspose*'. #28469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ntiable(linear)` function from component functons. * `Builtin.differentiableFunction_*` Takes an original function, a JVP function and a VJP function and returns a `@differentiable` function. Pseudo-declaration: ```swift func differentiableFunction_arity{arity}[_throws]?{throws}<T...{arity}, U>( _ original: __owned @escaping (T...{arity}) {throws}? -> U, _ jvp: __owned @escaping (T...{arity}) {throws}? -> (value: U, differential: (T.TangentVector...{arity}) -> U.TangentVector), _ vjp: __owned @escaping (T...{arity}) {throws}? -> (value: U, pullback: (U.TangentVector) -> (T.TangentVector...{arity})) ) -> @differentiable (T...{arity}) {throws}? -> U where T...{arity} : Differentiable, U : Differentiable ``` * `Builtin.linearFunction_*` Takes an original function and a transpose function and returns a `@differentiable` function. Pseudo-declaration: ```swift func linearFunction_arity{arity}[_throws]?{throws}<T...{arity}, U>( _ original: __owned @escaping (T...{arity}) {throws}? -> U, _ transpose: __owned @escaping (U.TangentVector) {throws}? -> (T.TangentVector...{arity}) ) -> @differentiable (T...{arity}) {throws}? -> U where T...{arity} : Differentiable & AdditiveArithmetic, U : Differentiable & AdditiveArithmetic ``` These builtins will be used to write unit tests for `@differentiable` and `@differentiable(linear)` function types that do not necessarily depend on the differentiation transform. TODO: - SR-11848: For robustness, we need SIL FileCheck tests for all AD builtins. These have not been added for `Builtin.autodiffApply*`, so I'm leaving this as a future task. - SR-11847: Update `differentiableFunction(from:)` to use `Builtin.differentiableFunction*` in its implementation. - SR-11849: Disallow non-top-level derivative registration. Resolves SR-11846.
Add support for applying the transpose function in a `@differentiable(linear)` function. The `Builtin.applyTranspose*` builtin takes a `@differentiable(linear)` function and a tangent vector, and returns the result of applying the transpose to the tangent vector. Pseudo-declaration: ```swift func applyTranspose_arity{arity}[_rethrowws?{r}]<T...{arity}, R>( _: @differentiable (T...) {r}?throws -> R, _: R ) {r}?rethrows -> (T...) where T: Differentiable & AdditiveArithmetic, R: Differentiable & AdditiveArithmetic ``` This patch also renames `Builtin.autodiffApply` to `Builtin.applyDerivative` for clarity. Resolves SR-11844 and SR-11851.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Note for self: Builtin.apply{Derivative,Transpose}*
and Builtin.{differentiable,linear}Function*
are all higher-order functions.
apply{Derivative,Transpose}
are rethrows
because they may throw depending on whether the argument function throws.
{differentiable,linear}Function
are not throwing because they merely form a bundle of functions.
I just unified the terminology to just "throws" for simplicity :) The |
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
3 similar comments
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
Add support for applying the transpose function in a
@differentiable(linear)
function. TheBuiltin.applyTranspose*
builtin takes a@differentiable(linear)
function and a tangent vector, and returns the result of applying the transpose to the tangent vector.Pseudo-declaration:
This patch also renames
Builtin.autodiffApply
toBuiltin.applyDerivative
for clarity, and fixes a bug inLinearDifferentiableSILFunctionTypeLowering
where it expected 3 component values inrebuildAggregate
instead of 2.Resolves SR-11844 and SR-11851.