Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 3b1f373

Browse files
Andr0id100BradLarson
authored andcommitted
DenseNet Implementaion (#213)
* Create DenseNet.swift Implementation for the DenseNet Architecture * Added test for DenseNet * Fixed Typo * Refractored the code * Updated ConvPair description Co-Authored-By: Richard Wei <[email protected]> * Fixed comment Co-Authored-By: Richard Wei <[email protected]> * Fixed typos, and improved code quality * Removed extra whitespace and '\' * Used 3 '/' for documentation and removed return * Renamed DenseNet to DenseNet121 * Renamed tests * Updated names * Renamed all DenseNet identifiers to DenseNet121
1 parent 3df2875 commit 3b1f373

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
// Original Paper:
18+
// Densely Connected Convolutional Networks
19+
// Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger
20+
// https://arxiv.org/pdf/1608.06993.pdf
21+
22+
public struct DenseNet121: Layer {
23+
public var conv = Conv(
24+
filterSize: 7,
25+
stride: 2,
26+
inputFilterCount: 3,
27+
outputFilterCount: 64
28+
)
29+
public var maxpool = MaxPool2D<Float>(
30+
poolSize: (3, 3),
31+
strides: (2, 2),
32+
padding: .same
33+
)
34+
public var denseBlock1 = DenseBlock(repetitionCount: 6, inputFilterCount: 64)
35+
public var transitionLayer1 = TransitionLayer(inputFilterCount: 256)
36+
public var denseBlock2 = DenseBlock(repetitionCount: 12, inputFilterCount: 128)
37+
public var transitionLayer2 = TransitionLayer(inputFilterCount: 512)
38+
public var denseBlock3 = DenseBlock(repetitionCount: 24, inputFilterCount: 256)
39+
public var transitionLayer3 = TransitionLayer(inputFilterCount: 1024)
40+
public var denseBlock4 = DenseBlock(repetitionCount: 16, inputFilterCount: 512)
41+
public var globalAvgPool = GlobalAvgPool2D<Float>()
42+
public var dense: Dense<Float>
43+
44+
public init(classCount: Int) {
45+
dense = Dense(inputSize: 1024, outputSize: classCount, activation: softmax)
46+
}
47+
48+
@differentiable
49+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
50+
let inputLayer = input.sequenced(through: conv, maxpool)
51+
let level1 = inputLayer.sequenced(through: denseBlock1, transitionLayer1)
52+
let level2 = level1.sequenced(through: denseBlock2, transitionLayer2)
53+
let level3 = level2.sequenced(through: denseBlock3, transitionLayer3)
54+
let output = level3.sequenced(through: denseBlock4, globalAvgPool, dense)
55+
return output
56+
}
57+
}
58+
59+
extension DenseNet121 {
60+
public struct Conv: Layer {
61+
public var batchNorm: BatchNorm<Float>
62+
public var conv: Conv2D<Float>
63+
64+
public init(
65+
filterSize: Int,
66+
stride: Int = 1,
67+
inputFilterCount: Int,
68+
outputFilterCount: Int
69+
) {
70+
batchNorm = BatchNorm(featureCount: inputFilterCount)
71+
conv = Conv2D(
72+
filterShape: (filterSize, filterSize, inputFilterCount, outputFilterCount),
73+
strides: (stride, stride),
74+
padding: .same
75+
)
76+
}
77+
78+
@differentiable
79+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
80+
conv(relu(batchNorm(input)))
81+
}
82+
}
83+
84+
/// A pair of a 1x1 `Conv` layer and a 3x3 `Conv` layer.
85+
public struct ConvPair: Layer {
86+
public var conv1x1: Conv
87+
public var conv3x3: Conv
88+
89+
public init(inputFilterCount: Int, growthRate: Int) {
90+
conv1x1 = Conv(
91+
filterSize: 1,
92+
inputFilterCount: inputFilterCount,
93+
outputFilterCount: inputFilterCount * 2
94+
)
95+
conv3x3 = Conv(
96+
filterSize: 3,
97+
inputFilterCount: inputFilterCount * 2,
98+
outputFilterCount: growthRate
99+
)
100+
}
101+
102+
@differentiable
103+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
104+
let conv1Output = conv1x1(input)
105+
let conv3Output = conv3x3(conv1Output)
106+
return conv3Output.concatenated(with: input, alongAxis: -1)
107+
}
108+
}
109+
110+
public struct DenseBlock: Layer {
111+
public var pairs: [ConvPair] = []
112+
113+
public init(repetitionCount: Int, growthRate: Int = 32, inputFilterCount: Int) {
114+
for i in 0..<repetitionCount {
115+
let filterCount = inputFilterCount + i * growthRate
116+
pairs.append(ConvPair(inputFilterCount: filterCount, growthRate: growthRate))
117+
}
118+
}
119+
120+
@differentiable
121+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
122+
pairs.differentiableReduce(input) { last, layer in
123+
layer(last)
124+
}
125+
}
126+
}
127+
128+
public struct TransitionLayer: Layer {
129+
public var conv: Conv
130+
public var pool: AvgPool2D<Float>
131+
132+
public init(inputFilterCount: Int) {
133+
conv = Conv(
134+
filterSize: 1,
135+
inputFilterCount: inputFilterCount,
136+
outputFilterCount: inputFilterCount / 2
137+
)
138+
pool = AvgPool2D(poolSize: (2, 2), strides: (2, 2), padding: .same)
139+
}
140+
141+
@differentiable
142+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
143+
input.sequenced(through: conv, pool)
144+
}
145+
}
146+
}

Tests/ImageClassificationTests/Inference.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ final class ImageClassificationInferenceTests: XCTestCase {
2121
override class func setUp() {
2222
Context.local.learningPhase = .inference
2323
}
24+
25+
func testDenseNet121() {
26+
let input = Tensor<Float>(
27+
randomNormal: [1, 224, 224, 3], mean: Tensor<Float>(0.5),
28+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
29+
let denseNet121 = DenseNet121(classCount: 1000)
30+
let denseNet121Result = denseNet121(input)
31+
XCTAssertEqual(denseNet121Result.shape, [1, 1000])
32+
}
2433

2534
func testLeNet() {
2635
let leNet = LeNet()
@@ -158,6 +167,7 @@ final class ImageClassificationInferenceTests: XCTestCase {
158167

159168
extension ImageClassificationInferenceTests {
160169
static var allTests = [
170+
("testDenseNet121", testDenseNet121),
161171
("testLeNet", testLeNet),
162172
("testResNet", testResNet),
163173
("testResNetV2", testResNetV2),

0 commit comments

Comments
 (0)