Skip to content

Commit 8ddbc68

Browse files
authored
Add missed vjp / jvp functions for floating-point constructors (#64417)
Fixes #64406
1 parent c6508a4 commit 8ddbc68

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ extension ${Self}: Differentiable {
5151
// Derivatives
5252
//===----------------------------------------------------------------------===//
5353

54+
/// Derivatives of constructors.
55+
${Availability(bits)}
56+
extension ${Self} {
57+
@inlinable
58+
@_transparent
59+
@derivative(of: init(_:))
60+
static func _vjpInit(x: ${Self})
61+
-> (value: ${Self}, pullback: (${Self}) -> ${Self}) {
62+
return (x, { v in v })
63+
}
64+
65+
@inlinable
66+
@_transparent
67+
@derivative(of: init(_:))
68+
static func _jvpInit(x: ${Self})
69+
-> (value: ${Self}, differential: (${Self}) -> ${Self}) {
70+
return (x, { dx in dx })
71+
}
72+
}
73+
5474
/// Derivatives of standard unary operators.
5575
${Availability(bits)}
5676
extension ${Self} {

test/AutoDiff/stdlib/floating_point.swift.gyb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
3434
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
3535
%end
3636

37+
FloatingPointDerivativeTests.test("${Self}.init") {
38+
expectEqual(1, gradient(at: ${Self}(4), of: ${Self}.init(_:)))
39+
expectEqual(10, pullback(at: ${Self}(4), of: ${Self}.init(_:))(${Self}(10)))
40+
41+
expectEqual(1, derivative(at: ${Self}(4), of: ${Self}.init(_:)))
42+
expectEqual(10, differential(at: ${Self}(4), of: ${Self}.init(_:))(${Self}(10)))
43+
}
44+
3745
FloatingPointDerivativeTests.test("${Self}.+") {
3846
expectEqual((1, 1), gradient(at: ${Self}(4), ${Self}(5), of: +))
3947
expectEqual((10, 10), pullback(at: ${Self}(4), ${Self}(5), of: +)(${Self}(10)))

0 commit comments

Comments
 (0)