13
13
// limitations under the License.
14
14
15
15
import TensorFlow
16
+ import LayerInit
16
17
17
18
// Original Paper:
18
19
// "Deep Residual Learning for Image Recognition"
@@ -24,6 +25,12 @@ import TensorFlow
24
25
// The structure of this implementation was inspired by the Flax ResNet example:
25
26
// https://github.com/google/flax/blob/master/examples/imagenet/models.py
26
27
28
/// A batch-normalization definition sequenced with a 2-D convolution definition.
public typealias AutoConvBN = AutoSequencedDefinition<
  AutoBatchNorm<(Int, Int, Int), Float>,
  AutoConv2D<Float>
>

/// Builds an `AutoConvBN` definition from convolution hyperparameters.
///
/// The batch-norm definition is created with momentum 0.9 and epsilon 1e-5, and the
/// convolution definition is created with `useBias: false`.
///
/// - Parameters:
///   - filterShape: Forwarded to `AutoConv2D` as `filterShape`.
///   - outputChannels: Forwarded to `AutoConv2D` as `outputChannels`.
///   - strides: Forwarded to `AutoConv2D` as `strides`. Defaults to `(1, 1)`.
///   - padding: Forwarded to `AutoConv2D` as `padding`. Defaults to `.valid`.
/// - Returns: The batch-norm definition chained, via `then`, with the convolution definition.
// NOTE(review): the batch norm is the receiver of `then`, so it is sequenced before the
// convolution even though the name reads "ConvBN" — confirm this ordering is intended.
public func autoConvBN(
  filterShape: (Int, Int), outputChannels: Int, strides: (Int, Int) = (1, 1),
  padding: Padding = .valid
) -> AutoConvBN {
  let norm = AutoBatchNorm<(Int, Int, Int), Float>(momentum: 0.9, epsilon: 1e-5)
  let conv = AutoConv2D<Float>(
    filterShape: filterShape, outputChannels: outputChannels, strides: strides,
    padding: padding, useBias: false)
  return norm.then(conv)
}
33
+
27
34
public struct ConvBN : Layer {
28
35
public var conv : Conv2D < Float >
29
36
public var norm : BatchNorm < Float >
@@ -43,6 +50,59 @@ public struct ConvBN: Layer {
43
50
}
44
51
}
45
52
53
+ // TODO(shadaj): OH NO
54
/// Definition of the two parallel paths of a residual block and their merge.
///
/// The first path is a (possibly empty) sequence of `AutoConvBN` definitions; the second is a
/// sequence of `AutoConvBN` definitions each followed by a function stage, then one more
/// `AutoConvBN`.
public typealias ConvPlusResidual = AutoSplitMerge<
  AutoSequencedMany<AutoConvBN>,
  AutoSequencedDefinition<
    AutoSequencedMany<
      AutoSequencedDefinition<
        AutoConvBN,
        AutoFunction<
          Tensor<Float>, Tensor<Float>,
          AutoConv2D<Float>.OutputShape, AutoConv2D<Float>.OutputShape
        >
      >
    >,
    AutoConvBN
  >,
  Tensor<Float>,
  AutoBatchNorm<(Int, Int, Int), Float>.InputShape
>

/// A `ConvPlusResidual` definition followed by a final function stage.
public typealias AutoResidualBlock = AutoSequencedDefinition<
  ConvPlusResidual,
  AutoFunction<Tensor<Float>, Tensor<Float>, (Int, Int, Int), (Int, Int, Int)>
>

/// Builds the definition of a single residual block.
///
/// The block splits its input into a shortcut path and a convolutional path, adds the two
/// results, and applies ReLU to the sum. Every convolution on the convolutional path except the
/// last is followed by ReLU.
///
/// - Parameters:
///   - inputFilters: Channel count of the block's input; compared against the block's output
///     channel count to decide whether the shortcut needs a projection.
///   - filters: Base channel count of the block. The block's output has `filters` channels when
///     `isBasic` is true and `filters * 4` channels otherwise.
///   - strides: Strides applied once on the convolutional path and on the projection shortcut.
///   - useLaterStride: Only used when `isBasic` is false. When true the strides go on the 3x3
///     convolution (ResNet V1.5); when false they go on the first 1x1 convolution (ResNet V1).
///   - isBasic: When true the convolutional path is two 3x3 convolutions; otherwise it is a
///     1x1, 3x3, 1x1 sequence.
/// - Returns: The block definition ending in a ReLU stage.
public func autoResidualBlock(
  inputFilters: Int, filters: Int, strides: (Int, Int), useLaterStride: Bool, isBasic: Bool
) -> AutoResidualBlock {
  let outFilters = filters * (isBasic ? 1 : 4)

  // The shortcut must be projected whenever the convolutional path changes the channel count
  // or downsamples along EITHER spatial axis; otherwise the two paths cannot be added.
  let needsProjection = inputFilters != outFilters || strides.0 != 1 || strides.1 != 1

  // An empty layer list leaves the shortcut as a pass-through of the block input.
  let shortcutLayers: [AutoConvBN] = needsProjection
    ? [autoConvBN(filterShape: (1, 1), outputChannels: outFilters, strides: strides)]
    : []
  let residual = AutoSequencedMany(layers: shortcutLayers)

  let earlyConvs: [AutoConvBN]
  let lastConv: AutoConvBN
  if isBasic {
    earlyConvs = [
      autoConvBN(filterShape: (3, 3), outputChannels: filters, strides: strides, padding: .same)
    ]
    lastConv = autoConvBN(filterShape: (3, 3), outputChannels: outFilters, padding: .same)
  } else {
    if useLaterStride {
      // Configure for ResNet V1.5 (the more common implementation).
      earlyConvs = [
        autoConvBN(filterShape: (1, 1), outputChannels: filters),
        autoConvBN(filterShape: (3, 3), outputChannels: filters, strides: strides, padding: .same),
      ]
    } else {
      // Configure for ResNet V1 (the paper implementation).
      earlyConvs = [
        autoConvBN(filterShape: (1, 1), outputChannels: filters, strides: strides),
        autoConvBN(filterShape: (3, 3), outputChannels: filters, padding: .same),
      ]
    }
    lastConv = autoConvBN(filterShape: (1, 1), outputChannels: outFilters)
  }

  // Every convolution except the last is followed by ReLU; the shape passes through unchanged.
  let earlyConvsWithRelu = earlyConvs.map({ (conv) in
    conv.then(AutoFunction(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
  })

  let convPath = AutoSequencedMany(layers: earlyConvsWithRelu).then(lastConv)

  // Both paths see the same input; their outputs are summed and the merged shape is taken
  // from the shortcut path.
  let convPlusResidual = AutoSplitMerge(
    layer1: residual,
    layer2: convPath,
    mergeOutputShape: { (l1, l2) in l1 }, mergeFn: SplitMergeFunctionWrapper({ $0 + $1 }))

  return convPlusResidual.then(
    AutoFunction<Tensor<Float>, Tensor<Float>, (Int, Int, Int), (Int, Int, Int)>(
      fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
}
105
+
46
106
public struct ResidualBlock : Layer {
47
107
public var projection : ConvBN
48
108
@noDerivative public let needsProjection : Bool
@@ -103,6 +163,50 @@ public struct ResidualBlock: Layer {
103
163
}
104
164
}
105
165
166
/// Definition of a full ResNet: an `AutoConvBN` stem, a function stage, a max pool, a sequence
/// of residual blocks, a global average pool, and a dense layer.
public typealias AutoResNet = AutoSequencedDefinition<
  AutoSequencedDefinition<
    AutoSequencedDefinition<
      AutoSequencedDefinition<
        AutoSequencedDefinition<
          AutoConvBN,
          AutoFunction<
            Tensor<Float>, Tensor<Float>,
            AutoConv2D<Float>.OutputShape, AutoMaxPool2D<Float>.InputShape
          >
        >,
        AutoMaxPool2D<Float>
      >,
      AutoSequencedMany<AutoResidualBlock>
    >,
    AutoGlobalAvgPool2D<Float>
  >,
  AutoDense<Float>
>

/// Builds the definition of a ResNet of the given depth.
///
/// - Parameters:
///   - classCount: Output size of the final dense layer.
///   - depth: Supplies the number of blocks per stage (`layerBlockSizes`) and whether basic
///     blocks are used (`usesBasicBlocks`).
///   - downsamplingInFirstStage: When true the stem is a 7x7, stride-2 convolution with 64
///     output channels followed by a 3x3, stride-2 max pool. When false the stem is a 3x3
///     convolution with 16 output channels and the pool is a 1x1, stride-1 no-op.
///   - useLaterStride: Forwarded to every `autoResidualBlock` call.
/// - Returns: The stem, ReLU, max pool, residual blocks, global average pool and dense layer
///   chained in that order.
public func autoResNet(
  classCount: Int, depth: ResNet.Depth, downsamplingInFirstStage: Bool = true,
  useLaterStride: Bool = true
) -> AutoResNet {
  let inputFilters = downsamplingInFirstStage ? 64 : 16

  let initialLayer: AutoConvBN
  let maxPool: AutoMaxPool2D<Float>
  if downsamplingInFirstStage {
    initialLayer = autoConvBN(
      filterShape: (7, 7), outputChannels: inputFilters, strides: (2, 2), padding: .same)
    maxPool = AutoMaxPool2D(poolSize: (3, 3), strides: (2, 2), padding: .same)
  } else {
    initialLayer = autoConvBN(
      filterShape: (3, 3), outputChannels: inputFilters, padding: .same)
    maxPool = AutoMaxPool2D(poolSize: (1, 1), strides: (1, 1))  // no-op
  }

  // Blocks that are not basic emit four times their base filter count.
  let expansion = depth.usesBasicBlocks ? 1 : 4

  var stackedBlocks: [AutoResidualBlock] = []
  var previousOutputFilters = inputFilters
  for (stage, blockCount) in depth.layerBlockSizes.enumerated() {
    // The base filter count doubles with every stage.
    let stageFilters = inputFilters * Int(pow(2.0, Double(stage)))
    for position in 0..<blockCount {
      // Only the first block of each stage after the first one downsamples.
      let opensLaterStage = stage > 0 && position == 0
      let strides = opensLaterStage ? (2, 2) : (1, 1)
      stackedBlocks.append(
        autoResidualBlock(
          inputFilters: previousOutputFilters, filters: stageFilters, strides: strides,
          useLaterStride: useLaterStride, isBasic: depth.usesBasicBlocks))
      previousOutputFilters = stageFilters * expansion
    }
  }

  return initialLayer
    .then(AutoFunction(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
    .then(maxPool)
    .then(AutoSequencedMany(layers: stackedBlocks))
    .then(AutoGlobalAvgPool2D())
    .then(AutoDense(outputSize: classCount))
}
209
+
106
210
/// An implementation of the ResNet v1 and v1.5 architectures, at various depths.
107
211
public struct ResNet : Layer {
108
212
public var initialLayer : ConvBN
0 commit comments