Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 8a15708

Browse files
committed
Simplify API for getting layers by key
1 parent f51eabf commit 8a15708

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

Models/LayerInit/AutoLayer.swift

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ public protocol AutoLayer {
99
}
1010

1111
extension AutoLayer {
12-
public func buildModel(inputShape: InputShape) -> InstanceType {
12+
public func buildModel(inputShape: InputShape) -> BuiltAutoLayer<InstanceType> {
1313
var keyDict: [AnyAutoLayerKey: Any] = [:]
1414
let (layerInstance, _) = self.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: \InstanceType.self, keyDict: &keyDict)
15-
return layerInstance
15+
return BuiltAutoLayer(layer: layerInstance, keyMapping: keyDict)
1616
}
1717

1818
public func buildModelWithKeys(inputShape: InputShape) -> (InstanceType, [AnyAutoLayerKey: Any]) {
@@ -21,3 +21,22 @@ extension AutoLayer {
2121
return (layerInstance, keyDict)
2222
}
2323
}
24+
25+
public struct BuiltAutoLayer<InstanceType: Layer>: Layer {
26+
public var layer: InstanceType
27+
@noDerivative let keyMapping: [AnyAutoLayerKey: Any]
28+
29+
public init(layer: InstanceType, keyMapping: [AnyAutoLayerKey: Any]) {
30+
self.layer = layer
31+
self.keyMapping = keyMapping
32+
}
33+
34+
@differentiable
35+
public func callAsFunction(_ input: InstanceType.Input) -> InstanceType.Output {
36+
return layer(input)
37+
}
38+
39+
public subscript<T>(index: AutoLayerKey<T>) -> T {
40+
return self.layer[keyPath: self.keyMapping[index] as! KeyPath<InstanceType, T>]
41+
}
42+
}

Models/LayerInit/AutoLayerKey.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ public class AutoLayerKey<T: Layer>: AnyAutoLayerKey {
1414
public override init() {}
1515
}
1616

17-
extension AutoLayerKey {
18-
public func readFrom<Instance: Layer>(layerInstance: Instance, keyDict: [AnyAutoLayerKey: Any]) -> T {
19-
return layerInstance[keyPath: keyDict[self] as! KeyPath<Instance, T>]
20-
}
21-
}
22-
2317
public struct KeyedAutoLayer<Underlying: AutoLayer>: AutoLayer {
2418
let underlying: Underlying
2519
let key: AutoLayerKey<InstanceType>

Models/LayerInit/AutoModule.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ public protocol AutoModule: AutoLayer {
66
var initializeLayer: LayerType { mutating get }
77
}
88

9-
extension AutoModule {
9+
extension AutoModule {
1010
public typealias InstanceType = LayerType.InstanceType
1111
public typealias InputShape = LayerType.InputShape
1212
public typealias OutputShape = LayerType.OutputShape

0 commit comments

Comments
 (0)