Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit b963b60

Browse files
committed
Simplify ResNet block definition
1 parent e756597 commit b963b60

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

Models/ImageClassification/ResNet.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ public func autoResidualBlock(inputFilters: Int, filters: Int, strides: (Int, In
6666
}
6767

6868
let earlyConvsWithRelu = earlyConvs.map({ (conv) in
69-
conv.then(AutoFunction(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
69+
conv.then(AutoFunction(fnShape: { $0 }, fn: { relu($0) }))
7070
})
7171

7272
let lastConvResult = AutoSequencedMany(layers: earlyConvsWithRelu).then(lastConv)
7373

7474
let convPlusResidual = AutoSplitMerge(
7575
layer1: residual,
7676
layer2: lastConvResult,
77-
mergeOutputShape: { (l1, l2) in l1 }, mergeFn: SplitMergeFunctionWrapper({ $0 + $1 }))
77+
mergeOutputShape: { (l1, l2) in l1 }, mergeFn: { $0 + $1 })
7878

79-
return convPlusResidual.then(AutoFunction(fnShape: { $0 }, fn: { (prev: Tensor<Float>) in relu(prev) }))
79+
return convPlusResidual.then(AutoFunction(fnShape: { $0 }, fn: { relu($0) }))
8080
}
8181

8282
public typealias AutoResNet = AutoSequenced<AutoSequenced<AutoSequenced<AutoSequenced<AutoSequenced<AutoConvBN, AutoFunction<Tensor<Float>, Tensor<Float>, AutoConv2D<Float>.OutputShape, AutoMaxPool2D<Float>.InputShape>>, AutoMaxPool2D<Float>>, AutoSequencedMany<AutoResidualBlock>>, AutoGlobalAvgPool2D<Float>>, AutoDense<Float>>

Models/LayerInit/AutoSplitMerge.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,15 @@ where Layer1.InputShape == Layer2.InputShape, Layer1.InstanceType.Input == Layer
3939
public typealias InputShape = Layer1.InputShape
4040
public typealias OutputShape = OutputShape
4141

42-
public init(layer1: Layer1, layer2: Layer2, mergeOutputShape: @escaping (Layer1.OutputShape, Layer2.OutputShape) -> OutputShape, mergeFn: SplitMergeFunctionWrapper<Layer1.InstanceType.Output, Layer2.InstanceType.Output, CommonOutput>) {
42+
public init(
43+
layer1: Layer1, layer2: Layer2,
44+
mergeOutputShape: @escaping (Layer1.OutputShape, Layer2.OutputShape) -> OutputShape,
45+
mergeFn: @escaping @differentiable (Layer1.InstanceType.Output, Layer2.InstanceType.Output) -> CommonOutput
46+
) {
4347
self.layer1 = layer1
4448
self.layer2 = layer2
4549
self.mergeOutputShape = mergeOutputShape
46-
self.mergeFn = mergeFn
50+
self.mergeFn = SplitMergeFunctionWrapper(mergeFn)
4751
}
4852

4953
public func buildModelWithOutputShape<Prefix>(inputShape: Layer1.InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OutputShape) {

0 commit comments

Comments
 (0)