Skip to content

Commit 155b5ed

Browse files
authored
[AutoDiff] Use @derivative attribute to register tgmath derivatives. (#28933)
- Compile stdlib with `-enable-experimental-cross-file-derivative-registration`. - This is necessary for using `@derivative` with tgmath functions, since some original functions (e.g. the `tan(_: Double) -> Double`) are imported from Clang. - Replace `@differentiable(jvp:vjp:)` with `@derivative` for tgmath functions. Progress towards TF-1085: using `@derivative` attribute in the stdlib.
1 parent fe6998d commit 155b5ed

File tree

3 files changed

+77
-45
lines changed

3 files changed

+77
-45
lines changed

cmake/modules/AddSwift.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,6 +1910,11 @@ function(add_swift_target_library name)
19101910
endif()
19111911
endif()
19121912

1913+
# SWIFT_ENABLE_TENSORFLOW
1914+
# NOTE(TF-1021): Enable cross-file derivative registration for stdlib.
1915+
list(APPEND swiftlib_swift_compile_flags_all
1916+
-Xllvm -enable-experimental-cross-file-derivative-registration)
1917+
# SWIFT_ENABLE_TENSORFLOW END
19131918

19141919
# Collect architecture agnostic SDK linker flags
19151920
set(swiftlib_link_flags_all ${SWIFTLIB_LINK_FLAGS})

stdlib/public/Platform/tgmath.swift.gyb

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -167,114 +167,136 @@ public func ldexp(_ x: ${T}, _ n : Int) -> ${T} {
167167
% if T == 'Float80':
168168
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
169169
% end
170-
@usableFromInline
171-
func _vjpExp(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
170+
@inlinable
171+
@derivative(of: exp)
172+
func _vjpExp(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
172173
let value = exp(x)
173174
return (value, { v in value * v })
174175
}
175176

176-
@usableFromInline
177-
func _vjpExp2(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
177+
@inlinable
178+
@derivative(of: exp2)
179+
func _vjpExp2(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
178180
let value = exp2(x)
179181
return (value, { v in v * ${T}(M_LN2) * value })
180182
}
181183

182-
@usableFromInline
183-
func _vjpLog(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
184+
@inlinable
185+
@derivative(of: log)
186+
func _vjpLog(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
184187
return (log(x), { v in v / x })
185188
}
186189

187-
@usableFromInline
188-
func _vjpLog10(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
190+
@inlinable
191+
@derivative(of: log10)
192+
func _vjpLog10(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
189193
return (log10(x), { v in v * ${T}(M_LOG10E) / x })
190194
}
191195

192-
@usableFromInline
193-
func _vjpLog2(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
196+
@inlinable
197+
@derivative(of: log2)
198+
func _vjpLog2(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
194199
return (log2(x), { v in v / (${T}(M_LN2) * x) })
195200
}
196201

197-
@usableFromInline
198-
func _vjpSin(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
202+
@inlinable
203+
@derivative(of: sin)
204+
func _vjpSin(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
199205
return (sin(x), { v in v * cos(x) })
200206
}
201207

202-
@usableFromInline
203-
func _vjpCos(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
208+
@inlinable
209+
@derivative(of: cos)
210+
func _vjpCos(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
204211
return (cos(x), { v in -v * sin(x) })
205212
}
206213

207-
@usableFromInline
208-
func _vjpTan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
214+
@inlinable
215+
@derivative(of: tan)
216+
func _vjpTan(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
209217
let value = tan(x)
210218
return (value, { v in v * (1 + value * value) })
211219
}
212220

213-
@usableFromInline
214-
func _vjpAsin(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
221+
@inlinable
222+
@derivative(of: asin)
223+
func _vjpAsin(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
215224
return (asin(x), { v in v / sqrt(1 - x * x) })
216225
}
217226

218-
@usableFromInline
219-
func _vjpAcos(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
227+
@inlinable
228+
@derivative(of: acos)
229+
func _vjpAcos(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
220230
return (acos(x), { v in -v / sqrt(1 - x * x) })
221231
}
222232

223-
@usableFromInline
224-
func _vjpAtan(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
233+
@inlinable
234+
@derivative(of: atan)
235+
func _vjpAtan(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
225236
return (atan(x), { v in v / (1 + x * x) })
226237
}
227238

228-
@usableFromInline
229-
func _vjpSinh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
239+
@inlinable
240+
@derivative(of: sinh)
241+
func _vjpSinh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
230242
return (sinh(x), { v in v * cosh(x) })
231243
}
232244

233-
@usableFromInline
234-
func _vjpCosh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
245+
@inlinable
246+
@derivative(of: cosh)
247+
func _vjpCosh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
235248
return (cosh(x), { v in v * sinh(x) })
236249
}
237250

238-
@usableFromInline
239-
func _vjpTanh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
251+
@inlinable
252+
@derivative(of: tanh)
253+
func _vjpTanh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
240254
let value = tanh(x)
241255
return (value, { v in v * (1 - value * value) })
242256
}
243257

244-
@usableFromInline
245-
func _vjpAsinh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
258+
@inlinable
259+
@derivative(of: asinh)
260+
func _vjpAsinh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
246261
return (asinh(x), { v in v / sqrt(1 + x * x) })
247262
}
248263

249-
@usableFromInline
250-
func _vjpAcosh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
264+
@inlinable
265+
@derivative(of: acosh)
266+
func _vjpAcosh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
251267
return (acosh(x), { v in v / sqrt(x * x - 1) })
252268
}
253269

254-
@usableFromInline
255-
func _vjpAtanh(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
270+
@inlinable
271+
@derivative(of: atanh)
272+
func _vjpAtanh(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
256273
return (atanh(x), { v in v / (1 - x * x) })
257274
}
258275

259-
@usableFromInline
260-
func _vjpExpm1(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
276+
@inlinable
277+
@derivative(of: expm1)
278+
func _vjpExpm1(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
261279
return (expm1(x), { v in exp(x) * v })
262280
}
263281

264-
@usableFromInline
265-
func _vjpLog1p(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
282+
@inlinable
283+
@derivative(of: log1p)
284+
func _vjpLog1p(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
266285
return (log1p(x), { v in v / (x + 1) })
267286
}
268287

269-
@usableFromInline
270-
func _vjpErf(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
288+
@inlinable
289+
@derivative(of: erf)
290+
func _vjpErf(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
271291
return (erf(x), { v in v * ${T}(M_2_SQRTPI) * exp(-x * x) })
272292
}
273293

274-
@usableFromInline
275-
func _vjpErfc(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
294+
@inlinable
295+
@derivative(of: erfc)
296+
func _vjpErfc(_ x: ${T}) -> (value: ${T}, pullback: (${T}) -> ${T}) {
276297
return (erfc(x), { v in v * -${T}(M_2_SQRTPI) * exp(-x * x) })
277298
}
299+
278300
// SWIFT_ENABLE_TENSORFLOW END
279301
% if T == 'Float80':
280302
#endif
@@ -398,7 +420,7 @@ def TypedBinaryFunctions():
398420
@_transparent
399421
// SWIFT_ENABLE_TENSORFLOW
400422
% if ufunc in HasVJP:
401-
@differentiable(vjp: _vjp${ufunc.capitalize()}(_:))
423+
@differentiable
402424
% end
403425
public func ${ufunc}(_ x: ${T}) -> ${T} {
404426
return ${T}.${ufunc}(x)
@@ -452,7 +474,7 @@ public func tgamma(_ x: Float80) -> Float80 {
452474
@_transparent
453475
// SWIFT_ENABLE_TENSORFLOW
454476
% if ufunc in HasVJP:
455-
@differentiable(vjp: _vjp${ufunc.capitalize()}(_:))
477+
@differentiable
456478
% end
457479
public func ${ufunc}(_ x: ${T}) -> ${T} {
458480
return ${T}.${ufunc}(x)
@@ -462,7 +484,7 @@ public func ${ufunc}(_ x: ${T}) -> ${T} {
462484
@_transparent
463485
// SWIFT_ENABLE_TENSORFLOW
464486
% if ufunc in HasVJP:
465-
@differentiable(vjp: _vjp${ufunc.capitalize()}(_:))
487+
@differentiable
466488
% end
467489
public func ${ufunc}(_ x: ${T}) -> ${T} {
468490
return x.rounded(.toNearestOrEven)

validation-test/ParseableInterface/verify_all_overlays.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
# RUN: %{python} %s %target-os %target-cpu %platform-sdk-overlay-dir %t \
88
# RUN: %target-swift-frontend -build-module-from-parseable-interface \
99
# RUN: -Fsystem %sdk/System/Library/PrivateFrameworks/ \
10+
# SWIFT_ENABLE_TENSORFLOW
11+
# NOTE(TF-1097): Remove flag when retroactive derivative registration is enabled
12+
# by default.
13+
# RUN: -Xllvm -enable-experimental-cross-file-derivative-registration \
14+
# SWIFT_ENABLE_TENSORFLOW END
1015
# RUN: | sort > %t/failures.txt
1116
# RUN: grep '# %target-os:' %s > %t/filter.txt || true
1217
# RUN: test ! -e %t/failures.txt || \

0 commit comments

Comments
 (0)