
Commit 939de51

Support ResNet-style models with skip connections
1 parent 5ed492b commit 939de51

6 files changed: +227 −1 lines changed

Examples/ResNet-CIFAR10/main.swift
Lines changed: 4 additions & 1 deletion

@@ -21,7 +21,8 @@ let batchSize = 10
let dataset = CIFAR10(batchSize: batchSize)

// Use the network sized for CIFAR-10
-var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
+var model = autoResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
+  .buildModel(inputShape: (32, 32, 3))

// the classic ImageNet optimizer setting diverges on CIFAR-10
// let optimizer = SGD(for: model, learningRate: 0.1, momentum: 0.9)

@@ -33,6 +34,7 @@ for (epoch, epochBatches) in dataset.training.prefix(10).enumerated() {
  Context.local.learningPhase = .training
  var trainingLossSum: Float = 0
  var trainingBatchCount = 0
+  let batchCount = epochBatches.count
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in

@@ -42,6 +44,7 @@ for (epoch, epochBatches) in dataset.training.prefix(10).enumerated() {
    trainingLossSum += loss.scalarized()
    trainingBatchCount += 1
    optimizer.update(&model, along: gradients)
+    print(trainingLossSum / Float(trainingBatchCount), Float(trainingBatchCount) / Float(batchCount))
  }

  Context.local.learningPhase = .inference

Models/ImageClassification/ResNet.swift
Lines changed: 104 additions & 0 deletions

@@ -13,6 +13,7 @@
// limitations under the License.

import TensorFlow
+import LayerInit

// Original Paper:
// "Deep Residual Learning for Image Recognition"

@@ -24,6 +25,12 @@ import TensorFlow
// The structure of this implementation was inspired by the Flax ResNet example:
// https://github.com/google/flax/blob/master/examples/imagenet/models.py

+public typealias AutoConvBN = AutoSequencedDefinition<AutoBatchNorm<(Int, Int, Int), Float>, AutoConv2D<Float>>
+public func autoConvBN(filterShape: (Int, Int), outputChannels: Int, strides: (Int, Int) = (1, 1), padding: Padding = .valid) -> AutoConvBN {
+  return AutoBatchNorm<(Int, Int, Int), Float>(momentum: 0.9, epsilon: 1e-5)
+    .then(AutoConv2D<Float>(filterShape: filterShape, outputChannels: outputChannels, strides: strides, padding: padding, useBias: false))
+}
+
public struct ConvBN: Layer {
  public var conv: Conv2D<Float>
  public var norm: BatchNorm<Float>

@@ -43,6 +50,59 @@ public struct ConvBN: Layer {
  }
}

+// TODO(shadaj): this type signature is unwieldy; find a way to simplify it.
+public typealias ConvPlusResidual = AutoSplitMerge<AutoSequencedMany<AutoConvBN>, AutoSequencedDefinition<AutoSequencedMany<AutoSequencedDefinition<AutoConvBN, AutoFunction<Tensor<Float>, Tensor<Float>, AutoConv2D<Float>.OutputShape, AutoConv2D<Float>.OutputShape>>>, AutoConvBN>, Tensor<Float>, AutoBatchNorm<(Int, Int, Int), Float>.InputShape>
+public typealias AutoResidualBlock = AutoSequencedDefinition<ConvPlusResidual, AutoFunction<Tensor<Float>, Tensor<Float>, (Int, Int, Int), (Int, Int, Int)>>
+public func autoResidualBlock(inputFilters: Int, filters: Int, strides: (Int, Int), useLaterStride: Bool, isBasic: Bool) -> AutoResidualBlock {
+  let outFilters = filters * (isBasic ? 1 : 4)
+  let needsProjection = (inputFilters != outFilters) || (strides.0 != 1)
+
+  let projection = needsProjection
+    ? autoConvBN(filterShape: (1, 1), outputChannels: outFilters, strides: strides)
+    : autoConvBN(filterShape: (1, 1), outputChannels: 1)
+
+  let residual = AutoSequencedMany(layers: needsProjection ? [projection] : [])
+
+  var earlyConvs: [AutoConvBN] = []
+  let lastConv: AutoConvBN
+  if isBasic {
+    earlyConvs = [
+      (autoConvBN(
+        filterShape: (3, 3), outputChannels: filters, strides: strides, padding: .same)),
+    ]
+    lastConv = autoConvBN(filterShape: (3, 3), outputChannels: outFilters, padding: .same)
+  } else {
+    if useLaterStride {
+      // Configure for ResNet V1.5 (the more common implementation).
+      earlyConvs.append(autoConvBN(filterShape: (1, 1), outputChannels: filters))
+      earlyConvs.append(
+        autoConvBN(filterShape: (3, 3), outputChannels: filters, strides: strides, padding: .same))
+    } else {
+      // Configure for ResNet V1 (the paper implementation).
+      earlyConvs.append(
+        autoConvBN(filterShape: (1, 1), outputChannels: filters, strides: strides))
+      earlyConvs.append(autoConvBN(filterShape: (3, 3), outputChannels: filters, padding: .same))
+    }
+    lastConv = autoConvBN(filterShape: (1, 1), outputChannels: outFilters)
+  }
+
+  let earlyConvsWithRelu = earlyConvs.map({ (conv) in
+    conv.then(AutoFunction(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
+  })
+
+  let lastConvResult = AutoSequencedMany(layers: earlyConvsWithRelu).then(lastConv)
+
+  let convPlusResidual = AutoSplitMerge(
+    layer1: residual,
+    layer2: lastConvResult,
+    mergeOutputShape: { (l1, l2) in l1 }, mergeFn: SplitMergeFunctionWrapper({ $0 + $1 }))
+
+  let finalResult = convPlusResidual.then(AutoFunction<Tensor<Float>, Tensor<Float>, (Int, Int, Int), (Int, Int, Int)>(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
+
+  return finalResult
+}
+
public struct ResidualBlock: Layer {
  public var projection: ConvBN
  @noDerivative public let needsProjection: Bool

@@ -103,6 +163,50 @@ public struct ResidualBlock: Layer {
  }
}

+public typealias AutoResNet = AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoConvBN, AutoFunction<Tensor<Float>, Tensor<Float>, AutoConv2D<Float>.OutputShape, AutoMaxPool2D<Float>.InputShape>>, AutoMaxPool2D<Float>>, AutoSequencedMany<AutoResidualBlock>>, AutoGlobalAvgPool2D<Float>>, AutoDense<Float>>
+public func autoResNet(
+  classCount: Int, depth: ResNet.Depth, downsamplingInFirstStage: Bool = true,
+  useLaterStride: Bool = true
+) -> AutoResNet {
+  let initialLayer: AutoConvBN
+  let maxPool: AutoMaxPool2D<Float>
+
+  let inputFilters: Int
+
+  if downsamplingInFirstStage {
+    inputFilters = 64
+    initialLayer = autoConvBN(
+      filterShape: (7, 7), outputChannels: inputFilters, strides: (2, 2), padding: .same)
+    maxPool = AutoMaxPool2D(poolSize: (3, 3), strides: (2, 2), padding: .same)
+  } else {
+    inputFilters = 16
+    initialLayer = autoConvBN(
+      filterShape: (3, 3), outputChannels: inputFilters, padding: .same)
+    maxPool = AutoMaxPool2D(poolSize: (1, 1), strides: (1, 1))  // no-op
+  }
+
+  var residualBlocks: [AutoResidualBlock] = []
+  var lastInputFilterCount = inputFilters
+  for (blockSizeIndex, blockSize) in depth.layerBlockSizes.enumerated() {
+    for blockIndex in 0..<blockSize {
+      let strides = ((blockSizeIndex > 0) && (blockIndex == 0)) ? (2, 2) : (1, 1)
+      let filters = inputFilters * Int(pow(2.0, Double(blockSizeIndex)))
+      let residualBlock = autoResidualBlock(
+        inputFilters: lastInputFilterCount, filters: filters, strides: strides,
+        useLaterStride: useLaterStride, isBasic: depth.usesBasicBlocks)
+      lastInputFilterCount = filters * (depth.usesBasicBlocks ? 1 : 4)
+      residualBlocks.append(residualBlock)
+    }
+  }
+
+  return initialLayer
+    .then(AutoFunction(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
+    .then(maxPool)
+    .then(AutoSequencedMany(layers: residualBlocks))
+    .then(AutoGlobalAvgPool2D())
+    .then(AutoDense(outputSize: classCount))
+}
+
/// An implementation of the ResNet v1 and v1.5 architectures, at various depths.
public struct ResNet: Layer {
  public var initialLayer: ConvBN
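
For context beyond the CIFAR-10 change above, a minimal sketch of how the new autoResNet definition is materialized and called (not part of this commit; the dummy batch and its shape are illustrative assumptions, and buildModel(inputShape:) is used exactly as in main.swift):

import TensorFlow

// Build the layer graph lazily, then materialize it for 32x32 RGB inputs.
var model = autoResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
  .buildModel(inputShape: (32, 32, 3))

// Forward pass on a dummy NHWC batch; the logits should have shape [1, 10].
let dummyBatch = Tensor<Float>(zeros: [1, 32, 32, 3])
let logits = model(dummyBatch)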

Models/LayerInit/AutoBatchNorm.swift
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import TensorFlow
+
+public struct AutoBatchNorm<Shape, Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
+  let axis: Int
+  let momentum: Scalar
+  let epsilon: Scalar
+
+  public typealias InstanceType = BatchNorm<Scalar>
+  public typealias InputShape = Shape
+  public typealias OutputShape = Shape
+
+  public init(
+    axis: Int = -1,
+    momentum: Scalar = 0.99,
+    epsilon: Scalar = 0.001
+  ) {
+    self.axis = axis
+    self.momentum = momentum
+    self.epsilon = epsilon
+  }
+
+  public func buildModelWithOutputShape(inputShape: Shape) -> (InstanceType, Shape) {
+    let inputShapeArray: [Int]
+    if let inputShapeTuple = inputShape as? (Int, Int, Int) {
+      inputShapeArray = [inputShapeTuple.0, inputShapeTuple.1, inputShapeTuple.2]
+    } else {
+      fatalError("Could not extract out elements of shape")
+    }
+
+    let featureCount = inputShapeArray[(inputShapeArray.count + axis) % inputShapeArray.count]
+    return (BatchNorm<Scalar>(featureCount: featureCount, axis: axis, momentum: momentum, epsilon: epsilon), inputShape)
+  }
+}
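
As a quick illustration of the shape threading (not part of the commit; the (32, 32, 3) shape is an arbitrary example), building an AutoBatchNorm directly shows how the feature count is taken from the axis of the inferred input shape:

// With the default axis of -1, the trailing (channel) dimension becomes the feature count,
// so a (32, 32, 3) input yields a BatchNorm over 3 features and an unchanged output shape.
let (norm, outputShape) = AutoBatchNorm<(Int, Int, Int), Float>(momentum: 0.9, epsilon: 1e-5)
  .buildModelWithOutputShape(inputShape: (32, 32, 3))
// norm: BatchNorm<Float> with featureCount == 3; outputShape == (32, 32, 3)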

Models/LayerInit/AutoDense.swift
Lines changed: 72 additions & 0 deletions

@@ -17,3 +17,75 @@ public struct AutoDense<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint
    return (Dense<Scalar>(inputSize: inputShape, outputSize: self.outputSize, activation: self.activation), self.outputSize)
  }
}
+
+// Workaround https://bugs.swift.org/browse/TF-1122
+public final class SplitMergeFunctionWrapper<Output1: Differentiable, Output2: Differentiable, CommonOutput: Differentiable> {
+  public typealias F = @differentiable (Output1, Output2) -> CommonOutput
+  public var f: F
+  public init(_ f: @escaping F) { self.f = f }
+}
+
+public struct SplitMergeInstance<Layer1: Layer, Layer2: Layer, CommonOutput: Differentiable>: Layer
+  where Layer1.Input == Layer2.Input, Layer1.TangentVector.VectorSpaceScalar == Layer2.TangentVector.VectorSpaceScalar {
+  var layer1: Layer1
+  var layer2: Layer2
+  @noDerivative let mergeFn: SplitMergeFunctionWrapper<Layer1.Output, Layer2.Output, CommonOutput>
+
+  public init(layer1: Layer1, layer2: Layer2, mergeFn: SplitMergeFunctionWrapper<Layer1.Output, Layer2.Output, CommonOutput>) {
+    self.layer1 = layer1
+    self.layer2 = layer2
+    self.mergeFn = mergeFn
+  }
+
+  @differentiable
+  public func callAsFunction(_ input: Layer1.Input) -> CommonOutput {
+    let layer1Out = layer1(input)
+    let layer2Out = layer2(input)
+    return mergeFn.f(layer1Out, layer2Out)
+  }
+}
+
+public struct AutoSplitMerge<Layer1: AutoLayer, Layer2: AutoLayer, CommonOutput: Differentiable, OutputShape>: AutoLayer
+  where Layer1.InputShape == Layer2.InputShape, Layer1.InstanceType.Input == Layer2.InstanceType.Input, Layer1.InstanceType.TangentVector.VectorSpaceScalar == Layer2.InstanceType.TangentVector.VectorSpaceScalar {
+  let layer1: Layer1
+  let layer2: Layer2
+
+  let mergeOutputShape: (Layer1.OutputShape, Layer2.OutputShape) -> OutputShape
+  let mergeFn: SplitMergeFunctionWrapper<Layer1.InstanceType.Output, Layer2.InstanceType.Output, CommonOutput>
+
+  public typealias InstanceType = SplitMergeInstance<Layer1.InstanceType, Layer2.InstanceType, CommonOutput>
+  public typealias InputShape = Layer1.InputShape
+  public typealias OutputShape = OutputShape
+
+  public init(layer1: Layer1, layer2: Layer2, mergeOutputShape: @escaping (Layer1.OutputShape, Layer2.OutputShape) -> OutputShape, mergeFn: SplitMergeFunctionWrapper<Layer1.InstanceType.Output, Layer2.InstanceType.Output, CommonOutput>) {
+    self.layer1 = layer1
+    self.layer2 = layer2
+    self.mergeOutputShape = mergeOutputShape
+    self.mergeFn = mergeFn
+  }
+
+  public func buildModelWithOutputShape(inputShape: Layer1.InputShape) -> (InstanceType, OutputShape) {
+    let (layer1Built, layer1OutputShape) = layer1.buildModelWithOutputShape(inputShape: inputShape)
+    let (layer2Built, layer2OutputShape) = layer2.buildModelWithOutputShape(inputShape: inputShape)
+    return (SplitMergeInstance(layer1: layer1Built, layer2: layer2Built, mergeFn: self.mergeFn), self.mergeOutputShape(layer1OutputShape, layer2OutputShape))
+  }
+}
+
+public struct AutoFunction<Input: Differentiable, Output: Differentiable, InputShape, OutputShape>: AutoLayer {
+  let fnShape: (InputShape) -> OutputShape
+  let fn: @differentiable (Input) -> Output
+
+  public typealias InstanceType = Function<Input, Output>
+  public typealias InputShape = InputShape
+  public typealias OutputShape = OutputShape
+
+  public init(fnShape: @escaping (InputShape) -> OutputShape, fn: @escaping @differentiable (Input) -> Output) {
+    self.fnShape = fnShape
+    self.fn = fn
+  }
+
+  public func buildModelWithOutputShape(inputShape: InputShape) -> (InstanceType, OutputShape) {
+    return (Function(fn), fnShape(inputShape))
+  }
+}
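
To see how these primitives express a skip connection, here is a minimal sketch in the spirit of autoResidualBlock in ResNet.swift; it is not part of the commit, and the channel count, input shape, and branch choices are illustrative assumptions.

// Identity branch: an empty AutoSequencedMany passes the input through unchanged.
let identityBranch = AutoSequencedMany(layers: [AutoConv2D<Float>]())

// Transform branch: a single 3x3 convolution that preserves the channel count.
let convBranch = AutoConv2D<Float>(
  filterShape: (3, 3), outputChannels: 16, strides: (1, 1), padding: .same, useBias: false)

// Add the two branches together (keeping the identity branch's shape), then apply a ReLU.
let skipBlock = AutoSplitMerge(
  layer1: identityBranch,
  layer2: convBranch,
  mergeOutputShape: { (identityShape, _) in identityShape },
  mergeFn: SplitMergeFunctionWrapper({ $0 + $1 }))
  .then(AutoFunction<Tensor<Float>, Tensor<Float>, (Int, Int, Int), (Int, Int, Int)>(
    fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))

// Materialize for a hypothetical (32, 32, 16) input, as in main.swift's buildModel call.
let block = skipBlock.buildModel(inputShape: (32, 32, 16))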

Models/LayerInit/AutoPool.swift
Lines changed: 13 additions & 0 deletions

@@ -43,6 +43,19 @@ public struct AutoAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint
  }
}

+public struct AutoGlobalAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
+  public typealias InstanceType = GlobalAvgPool2D<Scalar>
+  public typealias InputShape = (Int, Int, Int)
+  public typealias OutputShape = Int
+
+  public init() {
+  }
+
+  public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, Int) {
+    return (GlobalAvgPool2D<Scalar>(), inputShape.2)
+  }
+}
+
public struct AutoMaxPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
  let poolSize: (Int, Int)
  let strides: (Int, Int)
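
A one-line illustration of the new layer's shape contract (not part of the commit; the (8, 8, 64) input is an arbitrary example): global average pooling collapses the spatial dimensions, so the inferred output shape is just the channel count, ready to feed AutoDense.

let (pool, flattenedShape) = AutoGlobalAvgPool2D<Float>()
  .buildModelWithOutputShape(inputShape: (8, 8, 64))
// pool: GlobalAvgPool2D<Float>; flattenedShape == 64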

Models/LayerInit/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
add_library(LayerInit
+  AutoBatchNorm.swift
  AutoConv.swift
  AutoDense.swift
  AutoFlatten.swift
