Skip to content

Commit 64361eb

Browse files
authored
[AutoDiff] Update differentiableFunction(from:) to use differentiable function constructor builtins. (#28470)
This PR updates `differentiableFunction(from:)` to use differentiable function constructor builtins added in #28467. The standard library no longer has non-top-level derivative registration. Resolves SR-11847. Unblocks SR-11849.
1 parent e1514cf commit 64361eb

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

stdlib/public/Differentiation/DifferentiationSupport.swift

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,15 @@ public func differentiableFunction<T : Differentiable, R : Differentiable>(
215215
from vjp: @escaping (T)
216216
-> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
217217
) -> @differentiable (T) -> R {
218-
func original(_ x: T) -> R {
219-
return vjp(x).value
220-
}
221-
@differentiating(original)
222-
func derivative(_ x: T)
223-
-> (value: R, pullback: (R.TangentVector) -> T.TangentVector) {
224-
return vjp(x)
225-
}
226-
return original
218+
Builtin.differentiableFunction_arity1(
219+
/*original*/ { vjp($0).value },
220+
/*jvp*/ { _ in
221+
fatalError("""
222+
Functions formed with `differentiableFunction(from:)` cannot yet \
223+
be used with differential-producing differential operators.
224+
""")
225+
},
226+
/*vjp*/ vjp)
227227
}
228228

229229
/// Create a differentiable function from a vector-Jacobian products function.
@@ -232,19 +232,16 @@ public func differentiableFunction<T, U, R>(
232232
from vjp: @escaping (T, U)
233233
-> (value: R, pullback: (R.TangentVector)
234234
-> (T.TangentVector, U.TangentVector))
235-
) -> @differentiable (T, U) -> R
236-
where T : Differentiable, U : Differentiable, R : Differentiable {
237-
func original(_ x: T, _ y: U) -> R {
238-
return vjp(x, y).value
239-
}
240-
@differentiating(original)
241-
func derivative(_ x: T, _ y: U)
242-
-> (value: R,
243-
pullback: (R.TangentVector)
244-
-> (T.TangentVector, U.TangentVector)) {
245-
return vjp(x, y)
246-
}
247-
return original
235+
) -> @differentiable (T, U) -> R {
236+
Builtin.differentiableFunction_arity2(
237+
/*original*/ { vjp($0, $1).value },
238+
/*jvp*/ { _, _ in
239+
fatalError("""
240+
Functions formed with `differentiableFunction(from:)` cannot yet \
241+
be used with differential-producing differential operators.
242+
""")
243+
},
244+
/*vjp*/ vjp)
248245
}
249246

250247
public extension Differentiable {

0 commit comments

Comments
 (0)