Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit edd734e

Browse files
brettkooncesaeta
authored andcommitted
rebuild resnet block based approach (#170)
1 parent cf21585 commit edd734e

File tree

3 files changed

+180
-229
lines changed

3 files changed

+180
-229
lines changed

ResNet/Helpers.swift

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// TODO: Remove this when it's moved to the standard library.
16+
extension Array where Element: Differentiable {
17+
@differentiable(wrt: (self, initialResult), vjp: reduceDerivative)
18+
func differentiableReduce<Result: Differentiable>(
19+
_ initialResult: Result,
20+
_ nextPartialResult: @differentiable (Result, Element) -> Result
21+
) -> Result {
22+
return reduce(initialResult, nextPartialResult)
23+
}
24+
25+
func reduceDerivative<Result: Differentiable>(
26+
_ initialResult: Result,
27+
_ nextPartialResult: @differentiable (Result, Element) -> Result
28+
) -> (Result, (Result.TangentVector) -> (Array.TangentVector, Result.TangentVector)) {
29+
var pullbacks: [(Result.TangentVector)
30+
-> (Result.TangentVector, Element.TangentVector)] = []
31+
let count = self.count
32+
pullbacks.reserveCapacity(count)
33+
var result = initialResult
34+
for element in self {
35+
let (y, pb) = Swift.valueWithPullback(at: result, element, in: nextPartialResult)
36+
result = y
37+
pullbacks.append(pb)
38+
}
39+
return (value: result, pullback: { cotangent in
40+
var resultCotangent = cotangent
41+
var elementCotangents = TangentVector([])
42+
elementCotangents.base.reserveCapacity(count)
43+
for pullback in pullbacks.reversed() {
44+
let (newResultCotangent, elementCotangent) = pullback(resultCotangent)
45+
resultCotangent = newResultCotangent
46+
elementCotangents.base.append(elementCotangent)
47+
}
48+
return (TangentVector(elementCotangents.base.reversed()), resultCotangent)
49+
})
50+
}
51+
}

0 commit comments

Comments
 (0)