Commit aabbc18

[Swift] Remove custom derivative for LearningPhaseDependent.forward(_:). (#309)
Work around issues with `@differentiable` + `@derivative` attributes with different derivative generic signatures. Automatic differentiation can handle this enum `switch` now, so a custom derivative is no longer necessary. https://bugs.swift.org/browse/TF-1037 tracks this issue. Related discussion: swiftlang/swift#28621 (comment)
1 parent f2df87a commit aabbc18
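
Note: the change relies on the reverse-mode differentiation transform handling the enum `switch` in `forward(_:)` directly, so no hand-written pullback needs to be registered. A minimal sketch of the same pattern with hypothetical stand-in names (`Phase`, `forward`), assuming a toolchain where the `_Differentiation` module is available and the attribute is spelled `@differentiable(reverse)`:

```swift
import _Differentiation

enum Phase { case training, inference }

// Hypothetical stand-ins for forwardTraining/forwardInference.
@differentiable(reverse, wrt: x)
func forward(_ x: Float, phase: Phase) -> Float {
    // Ordinary control flow: only the branch actually taken is differentiated.
    switch phase {
    case .training:  return x * x
    case .inference: return 2 * x
    }
}

// No @derivative(of: forward) is registered anywhere; AD synthesizes the pullback.
let (value, pullback) = valueWithPullback(at: Float(3)) { x in forward(x, phase: .training) }
print(value, pullback(1))  // 9.0 6.0
```

Because the switch only selects which differentiable branch runs, the derivative of the taken branch is exactly the derivative of the whole function at that point, which is why the hand-written `gradForward` below is redundant.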

9 files changed: +1, -144 lines
swift/07_batchnorm.ipynb

Lines changed: 0 additions & 16 deletions
@@ -306,29 +306,13 @@
     "}\n",
     "\n",
     "extension LearningPhaseDependent {\n",
-    "    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer\n",
-    "    // protocol requirement, even though there is a `@differentiating(forward)` method below.\n",
-    "    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,\n",
-    "    // some require it. Investigate.\n",
     "    @differentiable\n",
     "    public func forward(_ input: Input) -> Output {\n",
     "        switch Context.local.learningPhase {\n",
     "        case .training: return forwardTraining(input)\n",
     "        case .inference: return forwardInference(input)\n",
     "        }\n",
     "    }\n",
-    "\n",
-    "    @differentiating(forward)\n",
-    "    func gradForward(_ input: Input) ->\n",
-    "        (value: Output, pullback: (Self.Output.TangentVector) ->\n",
-    "            (Self.TangentVector, Self.Input.TangentVector)) {\n",
-    "        switch Context.local.learningPhase {\n",
-    "        case .training:\n",
-    "            return valueWithPullback(at: input) { $0.forwardTraining ($1) }\n",
-    "        case .inference:\n",
-    "            return valueWithPullback(at: input) { $0.forwardInference($1) }\n",
-    "        }\n",
-    "    }\n",
     "}"
    ]
   },

swift/FastaiNotebook_07_batchnorm/Sources/FastaiNotebook_07_batchnorm/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08_data_block/Sources/FastaiNotebook_08_data_block/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08a_heterogeneous_dictionary/Sources/FastaiNotebook_08a_heterogeneous_dictionary/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_08c_data_block_generic/Sources/FastaiNotebook_08c_data_block_generic/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_09_optimizer/Sources/FastaiNotebook_09_optimizer/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_10_mixup_ls/Sources/FastaiNotebook_10_mixup_ls/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/FastaiNotebook_11_imagenette/Sources/FastaiNotebook_11_imagenette/07_batchnorm.swift

Lines changed: 0 additions & 16 deletions
@@ -24,29 +24,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
     @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(input)
         case .inference: return forwardInference(input)
         }
     }
-
-    @differentiating(forward)
-    func gradForward(_ input: Input) ->
-        (value: Output, pullback: (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining ($1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference($1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{

swift/Runnable11/Sources/Runnable11/07_batchnorm.swift

Lines changed: 1 addition & 16 deletions
@@ -22,28 +22,13 @@ public protocol LearningPhaseDependent: FALayer {
 }
 
 public extension LearningPhaseDependent {
-    // This `@differentiable` attribute is necessary, to tell the compiler that this satisfies the FALayer
-    // protocol requirement, even though there is a `@differentiating(forward)` method below.
-    // TODO: It seems nondeterministically necessary. Some subsequent notebooks import this successfully without it,
-    // some require it. Investigate.
-    @differentiable(vjp: gradForward)
+    @differentiable
     public func forward(_ input: Input) -> Output {
         switch Context.local.learningPhase {
         case .training: return forwardTraining(to: input)
         case .inference: return forwardInference(to: input)
         }
     }
-
-    func gradForward(_ input: Input) ->
-        (Output, (Self.Output.TangentVector) ->
-            (Self.TangentVector, Self.Input.TangentVector)) {
-        switch Context.local.learningPhase {
-        case .training:
-            return valueWithPullback(at: input) { $0.forwardTraining(to: $1) }
-        case .inference:
-            return valueWithPullback(at: input) { $0.forwardInference(to: $1) }
-        }
-    }
 }
 
 public protocol Norm: Layer where Input == Tensor<Scalar>, Output == Tensor<Scalar>{
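
The deleted code uses two older derivative-registration spellings: `@differentiating(forward)` in the notebook-derived packages and `@differentiable(vjp: gradForward)` in Runnable11; both were later folded into `@derivative(of:)`. None of that is needed here, but if a custom pullback were ever reintroduced, a hedged sketch of the current spelling on a hypothetical `Scale` type (not part of this repository; assumes the `_Differentiation` module and the `@differentiable(reverse)` spelling) might look like:

```swift
import _Differentiation

struct Scale: Differentiable {
    var factor: Float

    @differentiable(reverse)
    func forward(_ x: Float) -> Float { factor * x }

    // Registers a custom reverse-mode derivative for forward(_:).
    // The pullback returns tangents with respect to (self, x),
    // mirroring the shape of the removed gradForward.
    @derivative(of: forward)
    func vjpForward(_ x: Float)
        -> (value: Float, pullback: (Float) -> (TangentVector, Float)) {
        return (forward(x), { v in (TangentVector(factor: x * v), self.factor * v) })
    }
}
```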
