Skip to content

Commit 1782d72

Browse files
authored
[AutoDiff] [stdlib] Add a top-level 'withoutDerivative(at:in:)' API. (#25552)
The existing `withoutDerivative()` method on `Differentiable` is not general enough to support generic algorithms that are conditionally differentiable, for example: ```swift @differentiable(where T: Differentiable) func foo<T: Numeric>(x: T) -> T { 2 * x.withoutDerivative() } // Error: 'x' does not conform to 'Differentiable'. ``` We add a top-level function that makes this possible. ```swift @differentiable(where T: Differentiable) func foo<T: Numeric>(x: T) -> T { withoutDerivative(at: x) { x in 2 * x } // Okay! } ```
1 parent aa6d64c commit 1782d72

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

stdlib/public/core/AutoDiff.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,24 @@ public extension Differentiable where TangentVector == Self {
140140
}
141141

142142
public extension Differentiable {
143-
/// Identity function that stops gradients from propagating.
143+
/// Identity function that stops derivatives from propagating.
144+
@inlinable
144145
@inline(__always)
145146
@_semantics("autodiff.nonvarying")
146147
func withoutDerivative() -> Self { return self }
147148
}
148149

150+
/// Applies the given closure `body` to `x`. When used in a context where `x` is
151+
/// being differentiated with respect to, this function will not produce any
152+
/// derivative at `x`.
153+
// FIXME: Support throws-rethrows.
154+
@inlinable
155+
@inline(__always)
156+
@_semantics("autodiff.nonvarying")
157+
public func withoutDerivative<T, R>(at x: T, in body: (T) -> R) -> R {
158+
body(x)
159+
}
160+
149161
//===----------------------------------------------------------------------===//
150162
// Functional utilities
151163
//===----------------------------------------------------------------------===//

test/AutoDiff/custom_derivatives.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,12 @@ CustomDerivativesTests.test("ModifyGradientOfSum") {
8787
})
8888
}
8989

90+
CustomDerivativesTests.test("WithoutDerivative") {
91+
expectEqual(0, gradient(at: Float(4)) { x in
92+
withoutDerivative(at: x) { x in
93+
sinf(x) + cosf(x)
94+
}
95+
})
96+
}
97+
9098
runAllTests()

0 commit comments

Comments
 (0)