Skip to content

Commit 627ac35

Browse files
authored
[AutoDiff] Add JVPs for tgmath functions. (#29290)
Add JVPs for tgmath functions except `remainder(_:_:)` and `fmod(_:_:)`. JVP/VJP functions for unary tgmath functions are identical, but there is no way to register a derivative function as both a JVP/VJP of the same original function. TF-1111 tracks this; the issue will naturally resolve when linear functions and transposition are done and VJPs are removed. Partially resolves TF-1108.
1 parent 76d11a0 commit 627ac35

File tree

7 files changed

+154
-93
lines changed

7 files changed

+154
-93
lines changed

stdlib/public/Platform/tgmath_derivatives.swift.gyb

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
1212
//
1313
//===----------------------------------------------------------------------===//
14+
// This file defines derivatives for tgmath functions.
15+
//===----------------------------------------------------------------------===//
1416

1517
@usableFromInline
16-
@derivative(of: sqrt)
17-
func _vjpSqrt<T: FloatingPoint & Differentiable> (
18-
_ x: T
19-
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
20-
let value = sqrt(x)
21-
return (value, { v in v / (2 * value) })
18+
@derivative(of: fma)
19+
func _jvpFma<T: FloatingPoint & Differentiable> (
20+
_ x: T,
21+
_ y: T,
22+
_ z: T
23+
) -> (value: T, differential: (T, T, T) -> T) where T == T.TangentVector {
24+
return (fma(x, y, z), { (dx, dy, dz) in dx * y + dy * x + dz })
2225
}
2326

2427
@usableFromInline
@@ -31,6 +34,18 @@ func _vjpFma<T: FloatingPoint & Differentiable> (
3134
return (fma(x, y, z), { v in (v * y, v * x, v) })
3235
}
3336

37+
@usableFromInline
38+
@derivative(of: remainder)
39+
func _jvpRemainder<T: FloatingPoint & Differentiable> (
40+
_ x: T,
41+
_ y: T
42+
) -> (value: T, differential: (T, T) -> T) where T == T.TangentVector {
43+
fatalError("""
44+
Unimplemented JVP for 'remainder(_:)'. \
45+
https://bugs.swift.org/browse/TF-1108 tracks this issue
46+
""")
47+
}
48+
3449
@usableFromInline
3550
@derivative(of: remainder)
3651
func _vjpRemainder<T: FloatingPoint & Differentiable> (
@@ -40,6 +55,18 @@ func _vjpRemainder<T: FloatingPoint & Differentiable> (
4055
return (remainder(x, y), { v in (v, -v * ((x / y).rounded(.toNearestOrEven))) })
4156
}
4257

58+
@usableFromInline
59+
@derivative(of: fmod)
60+
func _jvpFmod<T: FloatingPoint & Differentiable> (
61+
_ x: T,
62+
_ y: T
63+
) -> (value: T, differential: (T, T) -> T) where T == T.TangentVector {
64+
fatalError("""
65+
Unimplemented JVP for 'fmod(_:)'. \
66+
https://bugs.swift.org/browse/TF-1108 tracks this issue
67+
""")
68+
}
69+
4370
@usableFromInline
4471
@derivative(of: fmod)
4572
func _vjpFmod<T: FloatingPoint & Differentiable> (
@@ -49,173 +76,188 @@ func _vjpFmod<T: FloatingPoint & Differentiable> (
4976
return (fmod(x, y), { v in (v, -v * ((x / y).rounded(.towardZero))) })
5077
}
5178

79+
%for derivative_kind in ['jvp', 'vjp']:
80+
% linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
81+
@usableFromInline
82+
@derivative(of: sqrt)
83+
func _${derivative_kind}Sqrt<T: FloatingPoint & Differentiable> (
84+
_ x: T
85+
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
86+
let value = sqrt(x)
87+
return (value, { v in v / (2 * value) })
88+
}
89+
5290
@usableFromInline
5391
@derivative(of: ceil)
54-
func _vjpCeil<T: FloatingPoint & Differentiable> (
92+
func _${derivative_kind}Ceil<T: FloatingPoint & Differentiable> (
5593
_ x: T
56-
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
94+
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
5795
return (ceil(x), { v in 0 })
5896
}
5997

6098
@usableFromInline
6199
@derivative(of: floor)
62-
func _vjpFloor<T: FloatingPoint & Differentiable> (
100+
func _${derivative_kind}Floor<T: FloatingPoint & Differentiable> (
63101
_ x: T
64-
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
102+
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
65103
return (floor(x), { v in 0 })
66104
}
67105

68106
@usableFromInline
69107
@derivative(of: round)
70-
func _vjpRound<T: FloatingPoint & Differentiable> (
108+
func _${derivative_kind}Round<T: FloatingPoint & Differentiable> (
71109
_ x: T
72-
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
110+
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
73111
return (round(x), { v in 0 })
74112
}
75113

76114
@usableFromInline
77115
@derivative(of: trunc)
78-
func _vjpTrunc<T: FloatingPoint & Differentiable> (
116+
func _${derivative_kind}Trunc<T: FloatingPoint & Differentiable> (
79117
_ x: T
80-
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
118+
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
81119
return (trunc(x), { v in 0 })
82120
}
121+
%end # for derivative_kind in ['jvp', 'vjp']:
83122

84-
%for T in ['Float', 'Double', 'Float80']:
85-
% if T == 'Float80':
123+
%for derivative_kind in ['jvp', 'vjp']:
124+
% linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
125+
% for T in ['Float', 'Double', 'Float80']:
126+
% if T == 'Float80':
86127
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
87-
% end
128+
% end
88129
@inlinable
89130
@derivative(of: exp)
90-
func _vjpExp(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
131+
func _${derivative_kind}Exp(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
91132
let value = exp(x)
92133
return (value, { v in value * v })
93134
}
94135

95136
@inlinable
96137
@derivative(of: exp2)
97-
func _vjpExp2(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
138+
func _${derivative_kind}Exp2(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
98139
let value = exp2(x)
99140
return (value, { v in v * ${T}(M_LN2) * value })
100141
}
101142

102143
@inlinable
103144
@derivative(of: log)
104-
func _vjpLog(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
145+
func _${derivative_kind}Log(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
105146
return (log(x), { v in v / x })
106147
}
107148

108149
@inlinable
109150
@derivative(of: log10)
110-
func _vjpLog10(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
151+
func _${derivative_kind}Log10(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
111152
return (log10(x), { v in v * ${T}(M_LOG10E) / x })
112153
}
113154

114155
@inlinable
115156
@derivative(of: log2)
116-
func _vjpLog2(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
157+
func _${derivative_kind}Log2(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
117158
return (log2(x), { v in v / (${T}(M_LN2) * x) })
118159
}
119160

120161
@inlinable
121162
@derivative(of: sin)
122-
func _vjpSin(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
163+
func _${derivative_kind}Sin(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
123164
return (sin(x), { v in v * cos(x) })
124165
}
125166

126167
@inlinable
127168
@derivative(of: cos)
128-
func _vjpCos(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
169+
func _${derivative_kind}Cos(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
129170
return (cos(x), { v in -v * sin(x) })
130171
}
131172

132173
@inlinable
133174
@derivative(of: tan)
134-
func _vjpTan(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
175+
func _${derivative_kind}Tan(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
135176
let value = tan(x)
136177
return (value, { v in v * (1 + value * value) })
137178
}
138179

139180
@inlinable
140181
@derivative(of: asin)
141-
func _vjpAsin(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
182+
func _${derivative_kind}Asin(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
142183
return (asin(x), { v in v / sqrt(1 - x * x) })
143184
}
144185

145186
@inlinable
146187
@derivative(of: acos)
147-
func _vjpAcos(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
188+
func _${derivative_kind}Acos(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
148189
return (acos(x), { v in -v / sqrt(1 - x * x) })
149190
}
150191

151192
@inlinable
152193
@derivative(of: atan)
153-
func _vjpAtan(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
194+
func _${derivative_kind}Atan(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
154195
return (atan(x), { v in v / (1 + x * x) })
155196
}
156197

157198
@inlinable
158199
@derivative(of: sinh)
159-
func _vjpSinh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
200+
func _${derivative_kind}Sinh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
160201
return (sinh(x), { v in v * cosh(x) })
161202
}
162203

163204
@inlinable
164205
@derivative(of: cosh)
165-
func _vjpCosh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
206+
func _${derivative_kind}Cosh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
166207
return (cosh(x), { v in v * sinh(x) })
167208
}
168209

169210
@inlinable
170211
@derivative(of: tanh)
171-
func _vjpTanh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
212+
func _${derivative_kind}Tanh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
172213
let value = tanh(x)
173214
return (value, { v in v * (1 - value * value) })
174215
}
175216

176217
@inlinable
177218
@derivative(of: asinh)
178-
func _vjpAsinh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
219+
func _${derivative_kind}Asinh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
179220
return (asinh(x), { v in v / sqrt(1 + x * x) })
180221
}
181222

182223
@inlinable
183224
@derivative(of: acosh)
184-
func _vjpAcosh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
225+
func _${derivative_kind}Acosh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
185226
return (acosh(x), { v in v / sqrt(x * x - 1) })
186227
}
187228

188229
@inlinable
189230
@derivative(of: atanh)
190-
func _vjpAtanh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
231+
func _${derivative_kind}Atanh(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
191232
return (atanh(x), { v in v / (1 - x * x) })
192233
}
193234

194235
@inlinable
195236
@derivative(of: expm1)
196-
func _vjpExpm1(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
237+
func _${derivative_kind}Expm1(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
197238
return (expm1(x), { v in exp(x) * v })
198239
}
199240

200241
@inlinable
201242
@derivative(of: log1p)
202-
func _vjpLog1p(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
243+
func _${derivative_kind}Log1p(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
203244
return (log1p(x), { v in v / (x + 1) })
204245
}
205246

206247
@inlinable
207248
@derivative(of: erf)
208-
func _vjpErf(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
249+
func _${derivative_kind}Erf(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
209250
return (erf(x), { v in v * ${T}(M_2_SQRTPI) * exp(-x * x) })
210251
}
211252

212253
@inlinable
213254
@derivative(of: erfc)
214-
func _vjpErfc(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
255+
func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) {
215256
return (erfc(x), { v in v * -${T}(M_2_SQRTPI) * exp(-x * x) })
216257
}
217258

218-
% if T == 'Float80':
259+
% if T == 'Float80':
219260
#endif
220-
% end
221-
%end
261+
% end # if T == 'Float80':
262+
% end # for T in ['Float', 'Double', 'Float80']:
263+
%end # for derivative_kind in ['jvp', 'vjp']:

test/AutoDiff/downstream/forward_mode_runtime.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target_run_simple_swift_forward_mode_differentiation
1+
// RUN: %target-run-simple-swift-forward-mode-differentiation
22
// REQUIRES: executable_test
33

44
import StdlibUnittest

test/AutoDiff/downstream/nonvaried_result.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %target-swift-frontend -Xllvm -sil-print-after=differentiation %s -emit-sil -o /dev/null 2>&1 | %FileCheck %s
22
// RUN: %target-run-simple-swift
33
// TODO: Test forward-mode differentiation when it supports control flow.
4-
// UN: %target_run_simple_swift_forward_mode_differentiation
4+
// UN: %target-run-simple-swift-forward-mode-differentiation
55
// REQUIRES: executable_test
66

77
// Test differentiation edge case: functions with non-varied results.

test/AutoDiff/downstream/simple_math.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: %target-run-simple-swift
22
// NOTE(TF-813): verify that enabling forward-mode does not affect reverse-mode.
3-
// RUN: %target_run_simple_swift_forward_mode_differentiation
3+
// RUN: %target-run-simple-swift-forward-mode-differentiation
44
// RUN: %target-swift-frontend -Xllvm -sil-print-after=differentiation %s -emit-sil -o /dev/null 2>&1 | %FileCheck %s
55
// REQUIRES: executable_test
66

0 commit comments

Comments
 (0)