Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 523933f

Browse files
authored
Adding shape-based inference tests for all image classification models (#198)
* Implemented inference tests with random tensors and starting weights for all classification models. * Made sure tests ran on Linux, reshaped output of SqueezeNet to match other classification models. * WideResNet is expressed in terms of CIFAR10, so altered the inputs and outputs appropriately. * Wrapping the reshaping line in SqueezeNet. * Reset ownership. * Reworked TensorShape initializers to use array literals. * Minor formatting tweak to SqueezeNet.
1 parent 21c694d commit 523933f

File tree

5 files changed

+185
-0
lines changed

5 files changed

+185
-0
lines changed

Models/ImageClassification/SqueezeNet.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ public struct SqueezeNet: Layer {
111111
let fired1 = convolved1.sequenced(through: fire2, fire3, fire4, maxPool4, fire5, fire6)
112112
let fired2 = fired1.sequenced(through: fire7, fire8, maxPool8, fire9)
113113
let convolved2 = fired2.sequenced(through: dropout, conv10, avgPool10)
114+
.reshaped(to: [input.shape[0], conv10.filter.shape[3]])
114115
return convolved2
115116
}
116117
}

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ let package = Package(
4242
name: "MiniGoDemo", dependencies: ["MiniGo"], path: "MiniGo",
4343
sources: ["main.swift"]),
4444
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
45+
.testTarget(name: "ImageClassificationTests", dependencies: ["ImageClassificationModels"]),
4546
.target(name: "Transformer", path: "Transformer"),
4647
.target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"),
4748
]
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
import XCTest
17+
18+
@testable import ImageClassificationModels
19+
20+
final class ImageClassificationInferenceTests: XCTestCase {
21+
override class func setUp() {
22+
Context.local.learningPhase = .inference
23+
}
24+
25+
func testLeNet() {
26+
let leNet = LeNet()
27+
let input = Tensor<Float>(
28+
randomNormal: [1, 28, 28, 1], mean: Tensor<Float>(0.5),
29+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
30+
let result = leNet(input)
31+
XCTAssertEqual(result.shape, [1, 10])
32+
}
33+
34+
func testResNet() {
35+
let inputCIFAR = Tensor<Float>(
36+
randomNormal: [1, 32, 32, 3], mean: Tensor<Float>(0.5),
37+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
38+
let resNet18CIFAR = ResNetBasic(inputKind: .resNet18, dataKind: .cifar)
39+
let resNet18CIFARResult = resNet18CIFAR(inputCIFAR)
40+
XCTAssertEqual(resNet18CIFARResult.shape, [1, 10])
41+
42+
let resNet34CIFAR = ResNetBasic(inputKind: .resNet34, dataKind: .cifar)
43+
let resNet34CIFARResult = resNet34CIFAR(inputCIFAR)
44+
XCTAssertEqual(resNet34CIFARResult.shape, [1, 10])
45+
46+
let resNet50CIFAR = ResNet(inputKind: .resNet50, dataKind: .cifar)
47+
let resNet50CIFARResult = resNet50CIFAR(inputCIFAR)
48+
XCTAssertEqual(resNet50CIFARResult.shape, [1, 10])
49+
50+
let resNet101CIFAR = ResNet(inputKind: .resNet101, dataKind: .cifar)
51+
let resNet101CIFARResult = resNet101CIFAR(inputCIFAR)
52+
XCTAssertEqual(resNet101CIFARResult.shape, [1, 10])
53+
54+
let resNet152CIFAR = ResNet(inputKind: .resNet152, dataKind: .cifar)
55+
let resNet152CIFARResult = resNet152CIFAR(inputCIFAR)
56+
XCTAssertEqual(resNet152CIFARResult.shape, [1, 10])
57+
58+
let inputImageNet = Tensor<Float>(
59+
randomNormal: [1, 224, 224, 3], mean: Tensor<Float>(0.5),
60+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
61+
let resNet18ImageNet = ResNetBasic(inputKind: .resNet18, dataKind: .imagenet)
62+
let resNet18ImageNetResult = resNet18ImageNet(inputImageNet)
63+
XCTAssertEqual(resNet18ImageNetResult.shape, [1, 1000])
64+
65+
let resNet34ImageNet = ResNetBasic(inputKind: .resNet34, dataKind: .imagenet)
66+
let resNet34ImageNetResult = resNet34ImageNet(inputImageNet)
67+
XCTAssertEqual(resNet34ImageNetResult.shape, [1, 1000])
68+
69+
let resNet50ImageNet = ResNet(inputKind: .resNet50, dataKind: .imagenet)
70+
let resNet50ImageNetResult = resNet50ImageNet(inputImageNet)
71+
XCTAssertEqual(resNet50ImageNetResult.shape, [1, 1000])
72+
73+
let resNet101ImageNet = ResNet(inputKind: .resNet101, dataKind: .imagenet)
74+
let resNet101ImageNetResult = resNet101ImageNet(inputImageNet)
75+
XCTAssertEqual(resNet101ImageNetResult.shape, [1, 1000])
76+
77+
let resNet152ImageNet = ResNet(inputKind: .resNet152, dataKind: .imagenet)
78+
let resNet152ImageNetResult = resNet152ImageNet(inputImageNet)
79+
XCTAssertEqual(resNet152ImageNetResult.shape, [1, 1000])
80+
}
81+
82+
func testResNetV2() {
83+
let input = Tensor<Float>(
84+
randomNormal: [1, 224, 224, 3], mean: Tensor<Float>(0.5),
85+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
86+
let resNet18ImageNet = PreActivatedResNet18(imageSize: 224, classCount: 1000)
87+
let resNet18ImageNetResult = resNet18ImageNet(input)
88+
XCTAssertEqual(resNet18ImageNetResult.shape, [1, 1000])
89+
90+
let resNet34ImageNet = PreActivatedResNet34(imageSize: 224, classCount: 1000)
91+
let resNet34ImageNetResult = resNet34ImageNet(input)
92+
XCTAssertEqual(resNet34ImageNetResult.shape, [1, 1000])
93+
}
94+
95+
func testSqueezeNet() {
96+
let input = Tensor<Float>(
97+
randomNormal: [1, 224, 224, 3], mean: Tensor<Float>(0.5),
98+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
99+
let squeezeNet = SqueezeNet(classCount: 1000)
100+
let squeezeNetResult = squeezeNet(input)
101+
XCTAssertEqual(squeezeNetResult.shape, [1, 1000])
102+
}
103+
104+
func testWideResNet() {
105+
let input = Tensor<Float>(
106+
randomNormal: [1, 32, 32, 3], mean: Tensor<Float>(0.5),
107+
standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
108+
let wideResNet16 = WideResNet(kind: .wideResNet16)
109+
let wideResNet16Result = wideResNet16(input)
110+
XCTAssertEqual(wideResNet16Result.shape, [1, 10])
111+
112+
let wideResNet16k10 = WideResNet(kind: .wideResNet16k10)
113+
let wideResNet16k10Result = wideResNet16k10(input)
114+
XCTAssertEqual(wideResNet16k10Result.shape, [1, 10])
115+
116+
let wideResNet22 = WideResNet(kind: .wideResNet22)
117+
let wideResNet22Result = wideResNet22(input)
118+
XCTAssertEqual(wideResNet22Result.shape, [1, 10])
119+
120+
let wideResNet22k10 = WideResNet(kind: .wideResNet22k10)
121+
let wideResNet22k10Result = wideResNet22k10(input)
122+
XCTAssertEqual(wideResNet22k10Result.shape, [1, 10])
123+
124+
let wideResNet28 = WideResNet(kind: .wideResNet28)
125+
let wideResNet28Result = wideResNet28(input)
126+
XCTAssertEqual(wideResNet28Result.shape, [1, 10])
127+
128+
let wideResNet28k12 = WideResNet(kind: .wideResNet28k12)
129+
let wideResNet28k12Result = wideResNet28k12(input)
130+
XCTAssertEqual(wideResNet28k12Result.shape, [1, 10])
131+
132+
let wideResNet40k1 = WideResNet(kind: .wideResNet40k1)
133+
let wideResNet40k1Result = wideResNet40k1(input)
134+
XCTAssertEqual(wideResNet40k1Result.shape, [1, 10])
135+
136+
let wideResNet40k2 = WideResNet(kind: .wideResNet40k2)
137+
let wideResNet40k2Result = wideResNet40k2(input)
138+
XCTAssertEqual(wideResNet40k2Result.shape, [1, 10])
139+
140+
let wideResNet40k4 = WideResNet(kind: .wideResNet40k4)
141+
let wideResNet40k4Result = wideResNet40k4(input)
142+
XCTAssertEqual(wideResNet40k4Result.shape, [1, 10])
143+
144+
let wideResNet40k8 = WideResNet(kind: .wideResNet40k8)
145+
let wideResNet40k8Result = wideResNet40k8(input)
146+
XCTAssertEqual(wideResNet40k8Result.shape, [1, 10])
147+
}
148+
}
149+
150+
extension ImageClassificationInferenceTests {
151+
static var allTests = [
152+
("testLeNet", testLeNet),
153+
("testResNet", testResNet),
154+
("testResNetV2", testResNetV2),
155+
("testSqueezeNet", testSqueezeNet),
156+
("testWideResNet", testWideResNet),
157+
]
158+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
#if !os(macOS)
18+
public func allTests() -> [XCTestCaseEntry] {
19+
return [
20+
testCase(ImageClassificationInferenceTests.allTests),
21+
]
22+
}
23+
#endif

Tests/LinuxMain.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import XCTest
22

3+
import ImageClassificationTests
34
import MiniGoTests
45

56
var tests = [XCTestCaseEntry]()
7+
tests += ImageClassificationTests.allTests()
68
tests += MiniGoTests.allTests()
79
XCTMain(tests)

0 commit comments

Comments
 (0)