Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit e756597

Browse files
committed
Simplify API for getting layers by key
1 parent 6ecfd51 commit e756597

File tree

5 files changed

+25
-12
lines changed

5 files changed

+25
-12
lines changed

Examples/LeNet-MNIST/main.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ let denseDef = AutoConv2D<Float>(filterShape: (5, 5), outputChannels: 6, padding
4141
.then(AutoDense(outputSize: 84, activation: relu))
4242
.then(KeyedAutoLayer(AutoDense(outputSize: 10), key: lastDenseKey))
4343

44-
var (classifier, keys) = denseDef.buildModelWithKeys(inputShape: (28, 28, 1))
44+
var classifier = denseDef.buildModel(inputShape: (28, 28, 1))
4545
classifier.move(to: device)
4646

4747
var optimizer = SGD(for: classifier, learningRate: 0.1)
@@ -123,5 +123,5 @@ for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
123123
(\(String(format: "%.3f", testStats.accuracy))%)
124124
""")
125125

126-
print(lastDenseKey.readFrom(layerInstance: classifier, keyDict: keys).bias)
126+
print(classifier[lastDenseKey].bias)
127127
}

Models/ImageClassification/ResNet.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public func autoResNet(
125125

126126
/// An implementation of the ResNet v1 and v1.5 architectures, at various depths.
127127
public struct ResNet: Layer {
128-
public var underlying: AutoResNet.InstanceType
128+
public var underlying: BuiltAutoLayer<AutoResNet.InstanceType>
129129

130130
/// Initializes a new ResNet v1 or v1.5 network model.
131131
///

Models/LayerInit/AutoLayer.swift

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ public protocol AutoLayer {
99
}
1010

1111
extension AutoLayer {
12-
public func buildModel(inputShape: InputShape) -> InstanceType {
12+
public func buildModel(inputShape: InputShape) -> BuiltAutoLayer<InstanceType> {
1313
var keyDict: [AnyAutoLayerKey: Any] = [:]
1414
let (layerInstance, _) = self.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: \InstanceType.self, keyDict: &keyDict)
15-
return layerInstance
15+
return BuiltAutoLayer(layer: layerInstance, keyMapping: keyDict)
1616
}
1717

1818
public func buildModelWithKeys(inputShape: InputShape) -> (InstanceType, [AnyAutoLayerKey: Any]) {
@@ -21,3 +21,22 @@ extension AutoLayer {
2121
return (layerInstance, keyDict)
2222
}
2323
}
24+
25+
public struct BuiltAutoLayer<InstanceType: Layer>: Layer {
26+
public var layer: InstanceType
27+
@noDerivative let keyMapping: [AnyAutoLayerKey: Any]
28+
29+
public init(layer: InstanceType, keyMapping: [AnyAutoLayerKey: Any]) {
30+
self.layer = layer
31+
self.keyMapping = keyMapping
32+
}
33+
34+
@differentiable
35+
public func callAsFunction(_ input: InstanceType.Input) -> InstanceType.Output {
36+
return layer(input)
37+
}
38+
39+
public subscript<T>(index: AutoLayerKey<T>) -> T {
40+
return self.layer[keyPath: self.keyMapping[index] as! KeyPath<InstanceType, T>]
41+
}
42+
}

Models/LayerInit/AutoLayerKey.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ public class AutoLayerKey<T: Layer>: AnyAutoLayerKey {
1414
public override init() {}
1515
}
1616

17-
extension AutoLayerKey {
18-
public func readFrom<Instance: Layer>(layerInstance: Instance, keyDict: [AnyAutoLayerKey: Any]) -> T {
19-
return layerInstance[keyPath: keyDict[self] as! KeyPath<Instance, T>]
20-
}
21-
}
22-
2317
public struct KeyedAutoLayer<Underlying: AutoLayer>: AutoLayer {
2418
let underlying: Underlying
2519
let key: AutoLayerKey<InstanceType>

Models/LayerInit/AutoModule.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ public protocol AutoModule: AutoLayer {
66
var initializeLayer: LayerType { mutating get }
77
}
88

9-
extension AutoModule {
9+
extension AutoModule {
1010
public typealias InstanceType = LayerType.InstanceType
1111
public typealias InputShape = LayerType.InputShape
1212
public typealias OutputShape = LayerType.OutputShape

0 commit comments

Comments
 (0)