@@ -25,15 +25,15 @@ import LayerInit
25
25
// The structure of this implementation was inspired by the Flax ResNet example:
26
26
// https://github.com/google/flax/blob/master/examples/imagenet/models.py
27
27
28
- public typealias AutoConvBN = AutoSequencedDefinition < AutoBatchNorm < ( Int , Int , Int ) , Float > , AutoConv2D < Float > >
28
+ public typealias AutoConvBN = AutoSequenced < AutoBatchNorm < ( Int , Int , Int ) , Float > , AutoConv2D < Float > >
29
29
public func autoConvBN( filterShape: ( Int , Int ) , outputChannels: Int , strides: ( Int , Int ) = ( 1 , 1 ) , padding: Padding = . valid) -> AutoConvBN {
30
30
return AutoBatchNorm < ( Int , Int , Int ) , Float > ( momentum: 0.9 , epsilon: 1e-5 )
31
31
. then ( AutoConv2D < Float > ( filterShape: filterShape, outputChannels: outputChannels, strides: strides, padding: padding, useBias: false ) )
32
32
}
33
33
34
34
// TODO(shadaj): OH NO
35
- public typealias ConvPlusResidual = AutoSplitMerge < AutoSequencedMany < AutoConvBN > , AutoSequencedDefinition < AutoSequencedMany < AutoSequencedDefinition < AutoConvBN , AutoFunction < Tensor < Float > , Tensor < Float > , AutoConv2D < Float > . OutputShape , AutoConv2D < Float > . OutputShape > > > , AutoConvBN > , Tensor < Float > , AutoBatchNorm < ( Int , Int , Int ) , Float > . InputShape >
36
- public typealias AutoResidualBlock = AutoSequencedDefinition < ConvPlusResidual , AutoFunction < Tensor < Float > , Tensor < Float > , ( Int , Int , Int ) , ( Int , Int , Int ) > >
35
+ public typealias ConvPlusResidual = AutoSplitMerge < AutoSequencedMany < AutoConvBN > , AutoSequenced < AutoSequencedMany < AutoSequenced < AutoConvBN , AutoFunction < Tensor < Float > , Tensor < Float > , ( Int , Int , Int ) , ( Int , Int , Int ) > > > , AutoConvBN > , Tensor < Float > , ( Int , Int , Int ) >
36
+ public typealias AutoResidualBlock = AutoSequenced < ConvPlusResidual , AutoFunction < Tensor < Float > , Tensor < Float > , ( Int , Int , Int ) , ( Int , Int , Int ) > >
37
37
public func autoResidualBlock( inputFilters: Int , filters: Int , strides: ( Int , Int ) , useLaterStride: Bool , isBasic: Bool ) -> AutoResidualBlock {
38
38
let outFilters = filters * ( isBasic ? 1 : 4 )
39
39
let needsProjection = ( inputFilters != outFilters) || ( strides. 0 != 1 )
@@ -79,7 +79,7 @@ public func autoResidualBlock(inputFilters: Int, filters: Int, strides: (Int, In
79
79
return convPlusResidual. then ( AutoFunction ( fnShape: { $0 } , fn: { ( prev: Tensor < Float > ) in relu ( prev) } ) )
80
80
}
81
81
82
- public typealias AutoResNet = AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoConvBN , AutoFunction < Tensor < Float > , Tensor < Float > , AutoConv2D < Float > . OutputShape , AutoMaxPool2D < Float > . InputShape > > , AutoMaxPool2D < Float > > , AutoSequencedMany < AutoResidualBlock > > , AutoGlobalAvgPool2D < Float > > , AutoDense < Float > >
82
+ public typealias AutoResNet = AutoSequenced < AutoSequenced < AutoSequenced < AutoSequenced < AutoSequenced < AutoConvBN , AutoFunction < Tensor < Float > , Tensor < Float > , AutoConv2D < Float > . OutputShape , AutoMaxPool2D < Float > . InputShape > > , AutoMaxPool2D < Float > > , AutoSequencedMany < AutoResidualBlock > > , AutoGlobalAvgPool2D < Float > > , AutoDense < Float > >
83
83
public func autoResNet(
84
84
classCount: Int , depth: ResNet . Depth , downsamplingInFirstStage: Bool = true ,
85
85
useLaterStride: Bool = true
0 commit comments