@@ -25,79 +25,67 @@ public struct BatchNormConv2DBlock: Layer {
     public var conv1: Conv2D<Float>
     public var norm2: BatchNorm<Float>
     public var conv2: Conv2D<Float>
+    public var shortcut: Conv2D<Float>
+    let isExpansion: Bool
+    let dropout: Dropout<Float> = Dropout(probability: 0.3)
 
     public init(
-        filterShape: (Int, Int, Int, Int),
+        featureCounts: (Int, Int),
+        kernelSize: Int = 3,
         strides: (Int, Int) = (1, 1),
         padding: Padding = .same
     ) {
-        self.norm1 = BatchNorm(featureCount: filterShape.2)
-        self.conv1 = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
-        self.norm2 = BatchNorm(featureCount: filterShape.3)
-        self.conv2 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
+        self.norm1 = BatchNorm(featureCount: featureCounts.0)
+        self.conv1 = Conv2D(
+            filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
+            strides: strides,
+            padding: padding)
+        self.norm2 = BatchNorm(featureCount: featureCounts.1)
+        self.conv2 = Conv2D(filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.1),
+            strides: (1, 1),
+            padding: padding)
+        self.shortcut = Conv2D(filterShape: (1, 1, featureCounts.0, featureCounts.1),
+            strides: strides,
+            padding: padding)
+        self.isExpansion = featureCounts.1 != featureCounts.0 || strides != (1, 1)
     }
 
     @differentiable
     public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
-        let firstLayer = conv1(relu(norm1(input)))
-        return conv2(relu(norm2(firstLayer)))
+        let preact1 = relu(norm1(input))
+        var residual = conv1(preact1)
+        let preact2: Tensor<Float>
+        let shortcutResult: Tensor<Float>
+        if isExpansion {
+            shortcutResult = shortcut(preact1)
+            preact2 = relu(norm2(residual))
+        } else {
+            shortcutResult = input
+            preact2 = dropout(relu(norm2(residual)))
+        }
+        residual = conv2(preact2)
+        return residual + shortcutResult
     }
 }
 
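The rewritten block above is a full pre-activation residual unit: each convolution is preceded by batch norm and relu. When the block widens the channel count or applies a stride, the raw input no longer matches the residual branch's shape, so isExpansion routes the shortcut through a strided 1x1 projection; otherwise the input passes through untouched and dropout regularizes the second pre-activation. A minimal sketch of that shape logic, with illustrative sample values not taken from the diff:

// Expanding case, e.g. featureCounts (16, 32) with strides (2, 2):
// isExpansion == true, so the shortcut is a 1x1 conv projecting
// [N, 32, 32, 16] -> [N, 16, 16, 32] to match the residual branch.
let expanding = BatchNormConv2DBlock(featureCounts: (16, 32), strides: (2, 2))

// Identity case, e.g. featureCounts (32, 32) with the default (1, 1) strides:
// isExpansion == false, the input itself is the shortcut, and dropout applies.
let identity = BatchNormConv2DBlock(featureCounts: (32, 32))
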
 public struct WideResNetBasicBlock: Layer {
     public var blocks: [BatchNormConv2DBlock]
-    public var shortcut: Conv2D<Float>
 
     public init(
         featureCounts: (Int, Int),
         kernelSize: Int = 3,
         depthFactor: Int = 2,
-        widenFactor: Int = 1,
         initialStride: (Int, Int) = (2, 2)
     ) {
-        if initialStride == (1, 1) {
-            self.blocks = [
-                BatchNormConv2DBlock(
-                    filterShape: (
-                        kernelSize, kernelSize,
-                        featureCounts.0, featureCounts.1 * widenFactor
-                    ),
-                    strides: initialStride)
-            ]
-            self.shortcut = Conv2D(
-                filterShape: (1, 1, featureCounts.0, featureCounts.1 * widenFactor),
-                strides: initialStride)
-        } else {
-            self.blocks = [
-                BatchNormConv2DBlock(
-                    filterShape: (
-                        kernelSize, kernelSize,
-                        featureCounts.0 * widenFactor, featureCounts.1 * widenFactor
-                    ),
-                    strides: initialStride)
-            ]
-            self.shortcut = Conv2D(
-                filterShape: (1, 1, featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
-                strides: initialStride)
-        }
+        self.blocks = [BatchNormConv2DBlock(featureCounts: featureCounts, strides: initialStride)]
         for _ in 1..<depthFactor {
-            self.blocks += [
-                BatchNormConv2DBlock(
-                    filterShape: (
-                        kernelSize, kernelSize,
-                        featureCounts.1 * widenFactor, featureCounts.1 * widenFactor
-                    ),
-                    strides: (1, 1))
-            ]
-        }
+            self.blocks += [BatchNormConv2DBlock(featureCounts: (featureCounts.1, featureCounts.1))]
+        }
     }
 
     @differentiable
     public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
-        let blocksReduced = blocks.differentiableReduce(input) { last, layer in
-            relu(layer(last))
-        }
-        return relu(blocksReduced + shortcut(input))
+        return blocks.differentiableReduce(input) { $1($0) }
     }
 }
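With the projection shortcut folded into BatchNormConv2DBlock, the basic block reduces to a plain sequence, and differentiableReduce (the differentiable fold over a layer array) simply threads the activation through each block. A minimal sketch, with illustrative values (featureCounts (16, 128) matches l2 below when widenFactor is 8):

// depthFactor 2 yields one (possibly expanding) block followed by one identity block.
let group = WideResNetBasicBlock(featureCounts: (16, 128), depthFactor: 2, initialStride: (1, 1))
// group.blocks[0] maps 16 -> 128 channels; group.blocks[1] maps 128 -> 128.
// The closure { $1($0) } applies each block to the running activation, so here
// group(x) == group.blocks[1](group.blocks[0](x))
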
@@ -116,15 +104,12 @@ public struct WideResNet: Layer {
     public init(depthFactor: Int = 2, widenFactor: Int = 8) {
         self.l1 = Conv2D(filterShape: (3, 3, 3, 16), strides: (1, 1), padding: .same)
 
-        l2 = WideResNetBasicBlock(
-            featureCounts: (16, 16), depthFactor: depthFactor,
-            widenFactor: widenFactor, initialStride: (1, 1))
-        l3 = WideResNetBasicBlock(
-            featureCounts: (16, 32), depthFactor: depthFactor,
-            widenFactor: widenFactor)
-        l4 = WideResNetBasicBlock(
-            featureCounts: (32, 64), depthFactor: depthFactor,
-            widenFactor: widenFactor)
+        self.l2 = WideResNetBasicBlock(
+            featureCounts: (16, 16 * widenFactor), depthFactor: depthFactor, initialStride: (1, 1))
+        self.l3 = WideResNetBasicBlock(featureCounts: (16 * widenFactor, 32 * widenFactor),
+            depthFactor: depthFactor)
+        self.l4 = WideResNetBasicBlock(featureCounts: (32 * widenFactor, 64 * widenFactor),
+            depthFactor: depthFactor)
 
         self.norm = BatchNorm(featureCount: 64 * widenFactor)
         self.avgPool = AvgPool2D(poolSize: (8, 8), strides: (8, 8))