Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 4a9719a

Browse files
committed
Explore LayerModule protocol to enable better type inference
1 parent 2db6dd5 commit 4a9719a

File tree

3 files changed

+56
-33
lines changed

3 files changed

+56
-33
lines changed

Examples/VGG-Imagewoof/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import TensorFlow
1919
let batchSize = 32
2020

2121
let dataset = Imagewoof(batchSize: batchSize, inputSize: .full, outputSize: 224)
22-
var model = makeVGG16(classCount: 10).buildModel(inputShape: (224, 224, 3))
22+
var model = AutoVGG16(classCount: 10).buildModel(inputShape: (224, 224, 3))
2323
let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9, decay: 0.0005)
2424

2525
print("Starting training...")

Models/ImageClassification/VGG.swift

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,46 @@ import LayerInit
2020
// Karen Simonyan, Andrew Zisserman
2121
// https://arxiv.org/abs/1409.1556
2222

23-
public typealias AutoVGGBlock = AutoSequenced<AutoSequencedMany<AutoConv2D<Float>>, AutoMaxPool2D<Float>>
24-
func makeVGGBlock(featureCounts: (Int, Int, Int, Int), blockCount: Int) -> AutoVGGBlock {
25-
var blocks: [AutoConv2D<Float>] = [
26-
AutoConv2D<Float>(filterShape: (3, 3), outputChannels: featureCounts.1,
27-
padding: .same,
28-
activation: relu)]
29-
for _ in 1..<blockCount {
30-
blocks += [AutoConv2D(filterShape: (3, 3), outputChannels: featureCounts.3,
31-
padding: .same,
32-
activation: relu)]
33-
}
23+
public struct AutoVGGBlock: AutoModule {
24+
let featureCounts: (Int, Int, Int, Int)
25+
let blockCount: Int
3426

35-
return AutoSequencedMany(layers: blocks)
36-
.then(AutoMaxPool2D(poolSize: (2, 2), strides: (2, 2)))
27+
public typealias LayerType = AutoSequenced<AutoSequencedMany<AutoConv2D<Float>>, AutoMaxPool2D<Float>>
28+
public lazy var initializeLayer: LayerType = {
29+
var blocks: [AutoConv2D<Float>] = [
30+
AutoConv2D<Float>(filterShape: (3, 3), outputChannels: featureCounts.1,
31+
padding: .same,
32+
activation: relu)]
33+
34+
for _ in 1..<blockCount {
35+
blocks += [AutoConv2D(filterShape: (3, 3), outputChannels: featureCounts.3,
36+
padding: .same,
37+
activation: relu)]
38+
}
39+
40+
return AutoSequencedMany(layers: blocks)
41+
.then(AutoMaxPool2D(poolSize: (2, 2), strides: (2, 2)))
42+
}()
3743
}
3844

39-
// TODO(shadaj): oh no
40-
public typealias AutoVGG16Backbone = AutoSequenced<AutoSequenced<AutoSequenced<AutoSequenced<AutoVGGBlock, AutoVGGBlock>, AutoVGGBlock>, AutoVGGBlock>, AutoVGGBlock>
41-
public typealias AutoVGG16 = AutoSequenced<AutoSequenced<AutoSequenced<AutoSequenced<AutoVGG16Backbone, AutoFlatten<Float>>, AutoDense<Float>>, AutoDense<Float>>, AutoDense<Float>>
42-
43-
public func makeVGG16(classCount: Int = 1000) -> AutoVGG16 {
44-
let layer1 = makeVGGBlock(featureCounts: (3, 64, 64, 64), blockCount: 2)
45-
let layer2 = makeVGGBlock(featureCounts: (64, 128, 128, 128), blockCount: 2)
46-
let layer3 = makeVGGBlock(featureCounts: (128, 256, 256, 256), blockCount: 3)
47-
let layer4 = makeVGGBlock(featureCounts: (256, 512, 512, 512), blockCount: 3)
48-
let layer5 = makeVGGBlock(featureCounts: (512, 512, 512, 512), blockCount: 3)
49-
50-
let flatten = AutoFlatten<Float>()
51-
let dense1 = AutoDense<Float>(outputSize: 4096, activation: relu)
52-
let dense2 = AutoDense<Float>(outputSize: 4096, activation: relu)
53-
let output = AutoDense<Float>(outputSize: classCount)
54-
55-
let backbone = layer1.then(layer2).then(layer3).then(layer4).then(layer5)
56-
let fullModel = backbone.then(flatten).then(dense1).then(dense2).then(output)
57-
return fullModel
45+
public struct AutoVGG16: AutoModule {
46+
let classCount: Int
47+
48+
public init(classCount: Int = 1000) {
49+
self.classCount = classCount
50+
}
51+
52+
public lazy var initializeLayer = {
53+
return AutoVGGBlock(featureCounts: (3, 64, 64, 64), blockCount: 2)
54+
.then(AutoVGGBlock(featureCounts: (64, 128, 128, 128), blockCount: 2))
55+
.then(AutoVGGBlock(featureCounts: (128, 256, 256, 256), blockCount: 3))
56+
.then(AutoVGGBlock(featureCounts: (256, 512, 512, 512), blockCount: 3))
57+
.then(AutoVGGBlock(featureCounts: (512, 512, 512, 512), blockCount: 3))
58+
.then(AutoFlatten())
59+
.then(AutoDense(outputSize: 4096, activation: relu))
60+
.then(AutoDense(outputSize: 4096, activation: relu))
61+
.then(AutoDense(outputSize: classCount))
62+
}()
5863
}
5964

6065
public struct VGGBlock: Layer {

Models/LayerInit/AutoModule.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import TensorFlow
2+
3+
public protocol AutoModule: AutoLayer {
4+
associatedtype LayerType: AutoLayer
5+
6+
var initializeLayer: LayerType { mutating get }
7+
}
8+
9+
extension AutoModule {
10+
public typealias InstanceType = LayerType.InstanceType
11+
public typealias InputShape = LayerType.InputShape
12+
public typealias OutputShape = LayerType.OutputShape
13+
14+
public func buildModelWithOutputShape(inputShape: InputShape) -> (InstanceType, OutputShape) {
15+
var selfCopy = self
16+
return selfCopy.initializeLayer.buildModelWithOutputShape(inputShape: inputShape)
17+
}
18+
}

0 commit comments

Comments
 (0)