Commit b9b6899

Remove custom derivative for LearningPhaseDependent.forward(_:).

Work around issues where `@differentiable` and `@derivative` attributes have different derivative generic signatures. Related discussion: swiftlang/swift#28621 (comment). https://bugs.swift.org/browse/TF-1037 tracks this issue.
1 parent 6690c04 commit b9b6899
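
For context, here is a minimal sketch of the attribute pattern involved; it is not taken from the repository. The names `PhaseDependent` and `_vjpForward` are illustrative stand-ins, and the removed derivative is shown commented out in the newer `@derivative(of:)` spelling. The kept version relies entirely on compiler-synthesized differentiation through the learning-phase switch; pairing it with a hand-registered derivative is the combination that ran into the generic-signature issue referenced above.

import TensorFlow

// Hypothetical, simplified stand-in for the notebooks' `LearningPhaseDependent` extension.
public protocol PhaseDependent: Differentiable {
    associatedtype Scalar: TensorFlowFloatingPoint
    @differentiable
    func forwardTraining(_ input: Tensor<Scalar>) -> Tensor<Scalar>
    @differentiable
    func forwardInference(_ input: Tensor<Scalar>) -> Tensor<Scalar>
}

public extension PhaseDependent {
    // Kept: a plain `@differentiable` method; the compiler differentiates
    // through the `switch` on the learning phase by itself.
    @differentiable
    func forward(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        switch Context.local.learningPhase {
        case .training:  return forwardTraining(input)
        case .inference: return forwardInference(input)
        }
    }

    // Removed: a hand-registered derivative like the one below. Combining it
    // with the `@differentiable` attribute above is what hit the
    // different-derivative-generic-signature problem (TF-1037,
    // swiftlang/swift#28621); the commit deletes the real equivalent in the
    // diffs that follow.
    //
    // @derivative(of: forward)
    // func _vjpForward(_ input: Tensor<Scalar>)
    //     -> (value: Tensor<Scalar>,
    //         pullback: (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
    //     switch Context.local.learningPhase {
    //     case .training:  return valueWithPullback(at: input) { $0.forwardTraining($1) }
    //     case .inference: return valueWithPullback(at: input) { $0.forwardInference($1) }
    //     }
    // }
}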

8 files changed: 2 additions & 130 deletions

swift/FastaiNotebook_07_batchnorm/Sources/FastaiNotebook_07_batchnorm/07_batchnorm.swift

Lines changed: 1 addition & 18 deletions
@@ -26,30 +26,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@derivative(of: forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
-    @differentiable(vjp: gradForward)
+    @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @usableFromInline
-    // @derivative(of: forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{
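
A hedged note on the call side: nothing changes for code that differentiates through these layers, because the retained `@differentiable` attribute lets the compiler synthesize the derivative across the phase switch. The snippet below is a toy illustration of that behavior, not code from the notebooks; `phaseScaled` is a hypothetical stand-in for `forward`.

import TensorFlow

// Toy phase-dependent function: no custom pullback is registered anywhere.
@differentiable
func phaseScaled(_ x: Tensor<Float>) -> Tensor<Float> {
    switch Context.local.learningPhase {
    case .training:  return x * 2
    case .inference: return x
    }
}

let x = Tensor<Float>([1, 2, 3])

Context.local.learningPhase = .training
let gTrain = gradient(at: x) { phaseScaled($0).sum() }   // [2, 2, 2]

Context.local.learningPhase = .inference
let gInfer = gradient(at: x) { phaseScaled($0).sum() }   // [1, 1, 1]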

swift/FastaiNotebook_08_data_block/Sources/FastaiNotebook_08_data_block/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08a_heterogeneous_dictionary/Sources/FastaiNotebook_08a_heterogeneous_dictionary/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08c_data_block_generic/Sources/FastaiNotebook_08c_data_block_generic/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_09_optimizer/Sources/FastaiNotebook_09_optimizer/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_10_mixup_ls/Sources/FastaiNotebook_10_mixup_ls/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_11_imagenette/Sources/FastaiNotebook_11_imagenette/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
        case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/Runnable11/Sources/Runnable11/07_batchnorm.swift

Lines changed: 1 addition & 16 deletions
@@ -22,28 +22,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 public extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
-    @differentiable(vjp: gradForward)
+    @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(to: input)
         case .inference: return forwardInference(to: input)
         }
     }
-
-    func gradForward(_ input: Input) ->
-        (Output, (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining(to: $1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference(to: $1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{
