Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit f51eabf

Browse files
committed
Implement initial API for accessing instance layers
1 parent 1443c87 commit f51eabf

11 files changed

+75
-24
lines changed

Models/LayerInit/AutoBatchNorm.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public struct AutoBatchNorm<Shape, Scalar>: AutoLayer where Scalar: TensorFlowFl
1919
self.epsilon = epsilon
2020
}
2121

22-
public func buildModelWithOutputShape(inputShape: Shape) -> (InstanceType, Shape) {
22+
public func buildModelWithOutputShape<Prefix>(inputShape: Shape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, Shape) {
2323
let inputShapeArray: [Int]
2424
if let inputShapeTuple = inputShape as? (Int, Int, Int) {
2525
inputShapeArray = [inputShapeTuple.0, inputShapeTuple.1, inputShapeTuple.2]

Models/LayerInit/AutoConv.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public struct AutoConv2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoin
3737
self.biasInitializer = biasInitializer
3838
}
3939

40-
public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, (Int, Int, Int)) {
40+
public func buildModelWithOutputShape<Prefix>(inputShape: (Int, Int, Int), keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, (Int, Int, Int)) {
4141
let outputShape: (Int, Int, Int)
4242
if (padding == .valid) {
4343
outputShape = (

Models/LayerInit/AutoDense.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public struct AutoDense<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint
1313
self.activation = activation
1414
}
1515

16-
public func buildModelWithOutputShape(inputShape: Int) -> (InstanceType, Int) {
16+
public func buildModelWithOutputShape<Prefix>(inputShape: Int, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, Int) {
1717
return (Dense<Scalar>(inputSize: inputShape, outputSize: self.outputSize, activation: self.activation), self.outputSize)
1818
}
1919
}

Models/LayerInit/AutoFlatten.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ public struct AutoFlatten<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoi
77

88
public init() {}
99

10-
public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, Int) {
10+
public func buildModelWithOutputShape<Prefix>(inputShape: (Int, Int, Int), keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, Int) {
1111
return (Flatten<Scalar>(), inputShape.0 * inputShape.1 * inputShape.2)
1212
}
1313
}

Models/LayerInit/AutoFunction.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public struct AutoFunction<Input: Differentiable, Output: Differentiable, InputS
1313
self.fn = fn
1414
}
1515

16-
public func buildModelWithOutputShape(inputShape: InputShape) -> (InstanceType, OutputShape) {
16+
public func buildModelWithOutputShape<Prefix>(inputShape: InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape) {
1717
return (Function(fn), fnShape(inputShape))
1818
}
1919
}

Models/LayerInit/AutoLayer.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,19 @@ public protocol AutoLayer {
55
associatedtype InputShape
66
associatedtype OutputShape
77

8-
func buildModelWithOutputShape(inputShape: InputShape) -> (InstanceType, OutputShape)
8+
func buildModelWithOutputShape<Prefix>(inputShape: InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape)
99
}
1010

1111
extension AutoLayer {
1212
public func buildModel(inputShape: InputShape) -> InstanceType {
13-
return self.buildModelWithOutputShape(inputShape: inputShape).0
13+
var keyDict: [AnyAutoLayerKey: Any] = [:]
14+
let (layerInstance, _) = self.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: \InstanceType.self, keyDict: &keyDict)
15+
return layerInstance
16+
}
17+
18+
public func buildModelWithKeys(inputShape: InputShape) -> (InstanceType, [AnyAutoLayerKey: Any]) {
19+
var keyDict: [AnyAutoLayerKey: Any] = [:]
20+
let (layerInstance, _) = self.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: \InstanceType.self, keyDict: &keyDict)
21+
return (layerInstance, keyDict)
1422
}
1523
}

Models/LayerInit/AutoLayerKey.swift

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import TensorFlow
2+
3+
public class AnyAutoLayerKey: Hashable {
4+
public func hash(into hasher: inout Hasher) {
5+
hasher.combine(ObjectIdentifier(self))
6+
}
7+
8+
public static func == (lhs: AnyAutoLayerKey, rhs: AnyAutoLayerKey) -> Bool {
9+
return lhs === rhs
10+
}
11+
}
12+
13+
public class AutoLayerKey<T: Layer>: AnyAutoLayerKey {
14+
public override init() {}
15+
}
16+
17+
extension AutoLayerKey {
18+
public func readFrom<Instance: Layer>(layerInstance: Instance, keyDict: [AnyAutoLayerKey: Any]) -> T {
19+
return layerInstance[keyPath: keyDict[self] as! KeyPath<Instance, T>]
20+
}
21+
}
22+
23+
public struct KeyedAutoLayer<Underlying: AutoLayer>: AutoLayer {
24+
let underlying: Underlying
25+
let key: AutoLayerKey<InstanceType>
26+
27+
public typealias InstanceType = Underlying.InstanceType
28+
public typealias InputShape = Underlying.InputShape
29+
public typealias OutputShape = Underlying.OutputShape
30+
31+
public init(_ underlying: Underlying, key: AutoLayerKey<InstanceType>) {
32+
self.underlying = underlying
33+
self.key = key
34+
}
35+
36+
public func buildModelWithOutputShape<Prefix>(inputShape: InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape) {
37+
let (layer, outputShape) = underlying.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: keyPathSoFar, keyDict: &keyDict)
38+
keyDict[self.key] = keyPathSoFar
39+
return (layer, outputShape)
40+
}
41+
}

Models/LayerInit/AutoModule.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ extension AutoModule {
1111
public typealias InputShape = LayerType.InputShape
1212
public typealias OutputShape = LayerType.OutputShape
1313

14-
public func buildModelWithOutputShape(inputShape: InputShape) -> (InstanceType, OutputShape) {
14+
public func buildModelWithOutputShape<Prefix>(inputShape: InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape) {
1515
var selfCopy = self
16-
return selfCopy.initializeLayer.buildModelWithOutputShape(inputShape: inputShape)
16+
return selfCopy.initializeLayer.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: keyPathSoFar, keyDict: &keyDict)
1717
}
1818
}

Models/LayerInit/AutoPool.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public struct AutoAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingP
1919
self.padding = padding
2020
}
2121

22-
public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, (Int, Int, Int)) {
22+
public func buildModelWithOutputShape<Prefix>(inputShape: (Int, Int, Int), keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, (Int, Int, Int)) {
2323
let outputShape: (Int, Int, Int)
2424
if (padding == .valid) {
2525
outputShape = (
@@ -51,7 +51,7 @@ public struct AutoGlobalAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFlo
5151
public init() {
5252
}
5353

54-
public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, Int) {
54+
public func buildModelWithOutputShape<Prefix>(inputShape: (Int, Int, Int), keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, Int) {
5555
return (GlobalAvgPool2D<Scalar>(), inputShape.2)
5656
}
5757
}
@@ -75,7 +75,7 @@ public struct AutoMaxPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingP
7575
self.padding = padding
7676
}
7777

78-
public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, (Int, Int, Int)) {
78+
public func buildModelWithOutputShape<Prefix>(inputShape: (Int, Int, Int), keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, (Int, Int, Int)) {
7979
let outputShape: (Int, Int, Int)
8080
if (padding == .valid) {
8181
outputShape = (

Models/LayerInit/AutoSequenced.swift

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@ where
99
let second: Layer2
1010

1111
public typealias InstanceType = Sequential<Layer1.InstanceType, Layer2.InstanceType>
12+
public typealias InputShape = Layer1.InputShape
13+
public typealias OutputShape = Layer2.OutputShape
1214

1315
public init(first: Layer1, second: Layer2) {
1416
self.first = first
1517
self.second = second
1618
}
1719

18-
public func buildModelWithOutputShape(inputShape: Layer1.InputShape) -> (InstanceType, Layer2.OutputShape) {
19-
let (firstInstance, firstOutputShape) = first.buildModelWithOutputShape(inputShape: inputShape)
20-
let (secondInstance, secondOutputShape) = second.buildModelWithOutputShape(inputShape: firstOutputShape)
20+
public func buildModelWithOutputShape<Prefix>(inputShape: Layer1.InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, Layer2.OutputShape) {
21+
let (firstInstance, firstOutputShape) = first.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.layer1), keyDict: &keyDict)
22+
let (secondInstance, secondOutputShape) = second.buildModelWithOutputShape(inputShape: firstOutputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.layer2), keyDict: &keyDict)
2123
return (Sequential(firstInstance, secondInstance), secondOutputShape)
2224
}
2325
}
@@ -30,7 +32,7 @@ extension AutoLayer {
3032

3133
public struct AutoSequencedManyInstance<LayerType: Layer>: Layer
3234
where LayerType.Input == LayerType.Output {
33-
var layers: [LayerType]
35+
public var layers: [LayerType]
3436

3537
@differentiable
3638
public func callAsFunction(_ input: LayerType.Input) -> LayerType.Output {
@@ -50,10 +52,10 @@ where
5052
self.layers = layers
5153
}
5254

53-
public func buildModelWithOutputShape(inputShape: LayerType.InputShape) -> (InstanceType, LayerType.OutputShape) {
55+
public func buildModelWithOutputShape<Prefix>(inputShape: LayerType.InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, LayerType.OutputShape) {
5456
var lastOutputShape = inputShape
55-
let builtInstances = self.layers.map({ autoLayer -> LayerType.InstanceType in
56-
let (instance, outputShape) = autoLayer.buildModelWithOutputShape(inputShape: lastOutputShape)
57+
let builtInstances = self.layers.enumerated().map({ (idx, autoLayer) -> LayerType.InstanceType in
58+
let (instance, outputShape) = autoLayer.buildModelWithOutputShape(inputShape: lastOutputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.layers[idx]), keyDict: &keyDict)
5759
lastOutputShape = outputShape
5860
return instance
5961
})

Models/LayerInit/AutoSplitMerge.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ public final class SplitMergeFunctionWrapper<Output1: Differentiable, Output2: D
99

1010
public struct SplitMergeInstance<Layer1: Layer, Layer2: Layer, CommonOutput: Differentiable>: Layer
1111
where Layer1.Input == Layer2.Input, Layer1.TangentVector.VectorSpaceScalar == Layer2.TangentVector.VectorSpaceScalar {
12-
var layer1: Layer1
13-
var layer2: Layer2
12+
public var layer1: Layer1
13+
public var layer2: Layer2
1414
@noDerivative let mergeFn: SplitMergeFunctionWrapper<Layer1.Output, Layer2.Output, CommonOutput>
1515

1616
public init(layer1: Layer1, layer2: Layer2, mergeFn: SplitMergeFunctionWrapper<Layer1.Output, Layer2.Output, CommonOutput>) {
@@ -46,9 +46,9 @@ where Layer1.InputShape == Layer2.InputShape, Layer1.InstanceType.Input == Layer
4646
self.mergeFn = mergeFn
4747
}
4848

49-
public func buildModelWithOutputShape(inputShape: Layer1.InputShape) -> (InstanceType, OutputShape) {
50-
let (layer1Built, layer1OutputShape) = layer1.buildModelWithOutputShape(inputShape: inputShape)
51-
let (layer2Built, layer2OutputShape) = layer2.buildModelWithOutputShape(inputShape: inputShape)
49+
public func buildModelWithOutputShape<Prefix>(inputShape: Layer1.InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape) {
50+
let (layer1Built, layer1OutputShape) = layer1.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.layer1), keyDict: &keyDict)
51+
let (layer2Built, layer2OutputShape) = layer2.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.layer2), keyDict: &keyDict)
5252
return (SplitMergeInstance(layer1: layer1Built, layer2: layer2Built, mergeFn: self.mergeFn), self.mergeOutputShape(layer1OutputShape, layer2OutputShape))
5353
}
5454
}

0 commit comments

Comments
 (0)