Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 97a4eac

Browse files
committed
Simplify ResNet block definition
1 parent edbe93c commit 97a4eac

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

Models/LayerInit/AutoSplitMerge.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,15 @@ where Layer1.InputShape == Layer2.InputShape, Layer1.InstanceType.Input == Layer
3939
public typealias InputShape = Layer1.InputShape
4040
public typealias OutputShape = OutputShape
4141

42-
public init(layer1: Layer1, layer2: Layer2, mergeOutputShape: @escaping (Layer1.OutputShape, Layer2.OutputShape) -> OutputShape, mergeFn: SplitMergeFunctionWrapper<Layer1.InstanceType.Output, Layer2.InstanceType.Output, CommonOutput>) {
42+
public init(
43+
layer1: Layer1, layer2: Layer2,
44+
mergeOutputShape: @escaping (Layer1.OutputShape, Layer2.OutputShape) -> OutputShape,
45+
mergeFn: @escaping @differentiable (Layer1.InstanceType.Output, Layer2.InstanceType.Output) -> CommonOutput
46+
) {
4347
self.layer1 = layer1
4448
self.layer2 = layer2
4549
self.mergeOutputShape = mergeOutputShape
46-
self.mergeFn = mergeFn
50+
self.mergeFn = SplitMergeFunctionWrapper(mergeFn)
4751
}
4852

4953
public func buildModelWithOutputShape<Prefix>(inputShape: Layer1.InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape) {

0 commit comments

Comments
 (0)