Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit e319a07

Browse files
mikowals authored and BradLarson committed
WideResNet - fix widenFactor and match model to citation (#193)
* add identity connections to WideResNet * rename preact1 for clarity * remove extra relu, add dropout * fix declaration * skip dropout in expansion blocks * remove enum, res layers to one line
1 parent 66d442d commit e319a07

File tree

1 file changed

+41
-56
lines changed

1 file changed

+41
-56
lines changed

Models/ImageClassification/WideResNet.swift

Lines changed: 41 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -25,79 +25,67 @@ public struct BatchNormConv2DBlock: Layer {
2525
public var conv1: Conv2D<Float>
2626
public var norm2: BatchNorm<Float>
2727
public var conv2: Conv2D<Float>
28+
public var shortcut: Conv2D<Float>
29+
let isExpansion: Bool
30+
let dropout: Dropout<Float> = Dropout(probability: 0.3)
2831

2932
public init(
30-
filterShape: (Int, Int, Int, Int),
33+
featureCounts: (Int, Int),
34+
kernelSize: Int = 3,
3135
strides: (Int, Int) = (1, 1),
3236
padding: Padding = .same
3337
) {
34-
self.norm1 = BatchNorm(featureCount: filterShape.2)
35-
self.conv1 = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
36-
self.norm2 = BatchNorm(featureCount: filterShape.3)
37-
self.conv2 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
38+
self.norm1 = BatchNorm(featureCount: featureCounts.0)
39+
self.conv1 = Conv2D(
40+
filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
41+
strides: strides,
42+
padding: padding)
43+
self.norm2 = BatchNorm(featureCount: featureCounts.1)
44+
self.conv2 = Conv2D(filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.1),
45+
strides: (1, 1),
46+
padding: padding)
47+
self.shortcut = Conv2D(filterShape: (1, 1, featureCounts.0, featureCounts.1),
48+
strides: strides,
49+
padding: padding)
50+
self.isExpansion = featureCounts.1 != featureCounts.0 || strides != (1, 1)
3851
}
3952

4053
@differentiable
4154
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
42-
let firstLayer = conv1(relu(norm1(input)))
43-
return conv2(relu(norm2(firstLayer)))
55+
let preact1 = relu(norm1(input))
56+
var residual = conv1(preact1)
57+
let preact2: Tensor<Float>
58+
let shortcutResult: Tensor<Float>
59+
if isExpansion {
60+
shortcutResult = shortcut(preact1)
61+
preact2 = relu(norm2(residual))
62+
} else {
63+
shortcutResult = input
64+
preact2 = dropout(relu(norm2(residual)))
65+
}
66+
residual = conv2(preact2)
67+
return residual + shortcutResult
4468
}
4569
}
4670

4771
public struct WideResNetBasicBlock: Layer {
4872
public var blocks: [BatchNormConv2DBlock]
49-
public var shortcut: Conv2D<Float>
5073

5174
public init(
5275
featureCounts: (Int, Int),
5376
kernelSize: Int = 3,
5477
depthFactor: Int = 2,
55-
widenFactor: Int = 1,
5678
initialStride: (Int, Int) = (2, 2)
5779
) {
58-
if initialStride == (1, 1) {
59-
self.blocks = [
60-
BatchNormConv2DBlock(
61-
filterShape: (
62-
kernelSize, kernelSize,
63-
featureCounts.0, featureCounts.1 * widenFactor
64-
),
65-
strides: initialStride)
66-
]
67-
self.shortcut = Conv2D(
68-
filterShape: (1, 1, featureCounts.0, featureCounts.1 * widenFactor),
69-
strides: initialStride)
70-
} else {
71-
self.blocks = [
72-
BatchNormConv2DBlock(
73-
filterShape: (
74-
kernelSize, kernelSize,
75-
featureCounts.0 * widenFactor, featureCounts.1 * widenFactor
76-
),
77-
strides: initialStride)
78-
]
79-
self.shortcut = Conv2D(
80-
filterShape: (1, 1, featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
81-
strides: initialStride)
82-
}
80+
self.blocks = [BatchNormConv2DBlock(featureCounts: featureCounts, strides: initialStride)]
8381
for _ in 1..<depthFactor {
84-
self.blocks += [
85-
BatchNormConv2DBlock(
86-
filterShape: (
87-
kernelSize, kernelSize,
88-
featureCounts.1 * widenFactor, featureCounts.1 * widenFactor
89-
),
90-
strides: (1, 1))
91-
]
92-
}
82+
self.blocks += [BatchNormConv2DBlock(featureCounts: (featureCounts.1, featureCounts.1))]
83+
}
9384
}
9485

9586
@differentiable
9687
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
97-
let blocksReduced = blocks.differentiableReduce(input) { last, layer in
98-
relu(layer(last))
99-
}
100-
return relu(blocksReduced + shortcut(input))
88+
return blocks.differentiableReduce(input) { $1($0) }
10189
}
10290
}
10391

@@ -116,15 +104,12 @@ public struct WideResNet: Layer {
116104
public init(depthFactor: Int = 2, widenFactor: Int = 8) {
117105
self.l1 = Conv2D(filterShape: (3, 3, 3, 16), strides: (1, 1), padding: .same)
118106

119-
l2 = WideResNetBasicBlock(
120-
featureCounts: (16, 16), depthFactor: depthFactor,
121-
widenFactor: widenFactor, initialStride: (1, 1))
122-
l3 = WideResNetBasicBlock(
123-
featureCounts: (16, 32), depthFactor: depthFactor,
124-
widenFactor: widenFactor)
125-
l4 = WideResNetBasicBlock(
126-
featureCounts: (32, 64), depthFactor: depthFactor,
127-
widenFactor: widenFactor)
107+
self.l2 = WideResNetBasicBlock(
108+
featureCounts: (16, 16 * widenFactor), depthFactor: depthFactor, initialStride: (1, 1))
109+
self.l3 = WideResNetBasicBlock(featureCounts: (16 * widenFactor, 32 * widenFactor),
110+
depthFactor: depthFactor)
111+
self.l4 = WideResNetBasicBlock(featureCounts: (32 * widenFactor, 64 * widenFactor),
112+
depthFactor: depthFactor)
128113

129114
self.norm = BatchNorm(featureCount: 64 * widenFactor)
130115
self.avgPool = AvgPool2D(poolSize: (8, 8), strides: (8, 8))

0 commit comments

Comments (0)