File tree Expand file tree Collapse file tree 2 files changed +32
-36
lines changed
private/DifferentiationUnittest Expand file tree Collapse file tree 2 files changed +32
-36
lines changed Original file line number Diff line number Diff line change @@ -34,6 +34,38 @@ public func withLeakChecking(
34
34
file: file, line: line)
35
35
}
36
36
37
+ @inlinable
38
+ @_semantics("autodiff.nonvarying")
39
+ public func withoutDerivative<T, R>(at x: T, in body: (T) -> R) -> R
40
+ {
41
+ body(x)
42
+ }
43
+
44
+ public extension Differentiable {
45
+ /// Applies the given closure to the derivative of `self`.
46
+ ///
47
+ /// Returns `self` like an identity function. When the return value is used in
48
+ /// a context where it is differentiated with respect to, applies the given
49
+ /// closure to the derivative of the return value.
50
+ @inlinable
51
+ @differentiable(reverse, wrt: self)
52
+ func withDerivative(_ body: @escaping (inout TangentVector) -> Void) -> Self {
53
+ return self
54
+ }
55
+
56
+ @inlinable
57
+ @derivative(of: withDerivative)
58
+ internal func _vjpWithDerivative(
59
+ _ body: @escaping (inout TangentVector) -> Void
60
+ ) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
61
+ return (self, { grad in
62
+ var grad = grad
63
+ body(&grad)
64
+ return grad
65
+ })
66
+ }
67
+ }
68
+
37
69
public extension TestSuite {
38
70
/// Execute test function and check expected leak count.
39
71
func testWithLeakChecking(
Original file line number Diff line number Diff line change @@ -31,42 +31,6 @@ public func withoutDerivative<T>(at x: T) -> T {
31
31
x
32
32
}
33
33
34
- /// Applies the given closure `body` to `x`. When used in a context where `x` is
35
- /// being differentiated with respect to, this function will not produce any
36
- /// derivative at `x`.
37
- // FIXME: Support throws-rethrows.
38
- @inlinable
39
- @inline ( __always)
40
- @_semantics ( " autodiff.nonvarying " )
41
- public func withoutDerivative< T, R> ( at x: T , in body: ( T ) -> R ) -> R {
42
- body ( x)
43
- }
44
-
45
- public extension Differentiable {
46
- /// Applies the given closure to the derivative of `self`.
47
- ///
48
- /// Returns `self` like an identity function. When the return value is used in
49
- /// a context where it is differentiated with respect to, applies the given
50
- /// closure to the derivative of the return value.
51
- @inlinable
52
- @differentiable ( reverse, wrt: self )
53
- func withDerivative( _ body: @escaping ( inout TangentVector ) -> Void ) -> Self {
54
- return self
55
- }
56
-
57
- @inlinable
58
- @derivative ( of: withDerivative)
59
- internal func _vjpWithDerivative(
60
- _ body: @escaping ( inout TangentVector ) -> Void
61
- ) -> ( value: Self , pullback: ( TangentVector ) -> TangentVector ) {
62
- return ( self , { grad in
63
- var grad = grad
64
- body ( & grad)
65
- return grad
66
- } )
67
- }
68
- }
69
-
70
34
//===----------------------------------------------------------------------===//
71
35
// Diagnostics
72
36
//===----------------------------------------------------------------------===//
You can’t perform that action at this time.
0 commit comments