Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 36b75a7

Browse files
authored
Make derivatives of public functions be @usableFromInline. (#929)
Prepare for compiler change: `@derivative` functions and their original functions must have the same access level. This should improve stability and eliminate AutoDiff symbol linkage issues.
1 parent 1337cf7 commit 36b75a7

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

Sources/TensorFlow/Core/MixedPrecision.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,13 @@ extension Tensor {
206206
#endif
207207

208208
extension Tensor where Scalar: TensorFlowFloatingPoint {
209+
@usableFromInline
209210
@derivative(of: toReducedPrecision)
210211
func _vjpToReducedPrecision() -> (value: Tensor, pullback: (Tensor) -> Tensor) {
211212
(toReducedPrecision, { $0.toFullPrecision })
212213
}
213214

215+
@usableFromInline
214216
@derivative(of: toFullPrecision)
215217
func _vjpToFullPrecision() -> (value: Tensor, pullback: (Tensor) -> Tensor) {
216218
(toFullPrecision, { $0.toReducedPrecision })

Sources/TensorFlow/Layer.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ extension Layer {
9494

9595
// TODO(TF-433, SR-11882): Remove this custom derivative when
9696
// differentiation supports `rethrows` functions and currying.
97-
@derivative(of: inferring(from:))
9897
@usableFromInline
98+
@derivative(of: inferring(from:))
9999
internal func _vjpInferring(from input: Input)
100100
-> (
101101
value: Output,

Sources/TensorFlow/Operators/NN.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ public func depthToSpace<Scalar>(_ input: Tensor<Scalar>, blockSize b: Int) -> T
971971
return _Raw.depthToSpace(input, blockSize: Int64(b))
972972
}
973973

974+
@usableFromInline
974975
@derivative(of: depthToSpace)
975976
func _vjpDepthToSpace<Scalar: TensorFlowFloatingPoint>(
976977
_ input: Tensor<Scalar>,
@@ -1048,6 +1049,7 @@ public func spaceToDepth<Scalar>(_ input: Tensor<Scalar>, blockSize b: Int) -> T
10481049
return _Raw.spaceToDepth(input, blockSize: Int64(b))
10491050
}
10501051

1052+
@usableFromInline
10511053
@derivative(of: spaceToDepth)
10521054
func _vjpSpaceToDepth<Scalar: TensorFlowFloatingPoint>(
10531055
_ input: Tensor<Scalar>,

0 commit comments

Comments
 (0)