
Commit 5eebdf8

[Swift] Remove custom derivative for LearningPhaseDependent.forward(_:).
Works around issues with `@differentiable` + `@derivative` attributes that have different derivative generic signatures. Automatic differentiation can now handle this enum `switch` directly, so a custom derivative is no longer necessary. https://bugs.swift.org/browse/TF-1037 tracks this issue. Related discussion: swiftlang/swift#28621 (comment)
1 parent f2df87a · commit 5eebdf8
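
The pattern this change relies on is, roughly, the sketch below: a `@differentiable` function may `switch` on a value that is not itself being differentiated, and the differentiation transform differentiates whichever branch executes, so a hand-written derivative like `gradForward` is no longer needed. This is a minimal illustration assuming a Swift for TensorFlow toolchain of this era, not code from the commit; `Phase` and `phase` are hypothetical stand-ins for `Context.local.learningPhase`.

enum Phase { case training, inference }
var phase = Phase.training   // hypothetical stand-in for Context.local.learningPhase

@differentiable
func forward(_ x: Float) -> Float {
    // The switch scrutinee is not an input being differentiated, so AD
    // simply differentiates whichever branch is taken at runtime.
    switch phase {
    case .training:  return x * 2
    case .inference: return x
    }
}

let g = gradient(at: Float(3), in: forward)   // 2.0 while phase == .training, 1.0 in inference

Before this support landed, the equivalent of `gradForward` had to dispatch `valueWithPullback(at:)` manually per branch, which is exactly what the code deleted in the diffs below does.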

8 files changed: +5 −130 lines

swift/FastaiNotebook_07_batchnorm/Sources/FastaiNotebook_07_batchnorm/07_batchnorm.swift

Lines changed: 4 additions & 18 deletions
@@ -19,34 +19,20 @@ public protocol LearningPhaseDependent: FALayer {
     associatedtype Input
     associatedtype Output
 
-    @differentiable func forwardTraining(_ input: Input) -> Output
-    @differentiable func forwardInference(_ input: Input) -> Output
+    @differentiable
+    func forwardTraining(_ input: Input) -> Output
+    @differentiable
+    func forwardInference(_ input: Input) -> Output
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08_data_block/Sources/FastaiNotebook_08_data_block/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08a_heterogeneous_dictionary/Sources/FastaiNotebook_08a_heterogeneous_dictionary/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08c_data_block_generic/Sources/FastaiNotebook_08c_data_block_generic/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_09_optimizer/Sources/FastaiNotebook_09_optimizer/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_10_mixup_ls/Sources/FastaiNotebook_10_mixup_ls/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_11_imagenette/Sources/FastaiNotebook_11_imagenette/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/Runnable11/Sources/Runnable11/07_batchnorm.swift

Lines changed: 1 addition & 16 deletions
@@ -22,28 +22,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 public extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
-    @differentiable(vjp: gradForward)
+    @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(to: input)
         case .inference: return forwardInference(to: input)
         }
     }
-
-    func gradForward(_ input: Input) ->
-        (Output, (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining(to: $1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference(to: $1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{
