// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

// Original Paper:
// Densely Connected Convolutional Networks
// Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger
// https://arxiv.org/pdf/1608.06993.pdf

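/// DenseNet-121: a strided 7x7 convolution and 3x3 max pool, followed by four dense blocks of
/// 6, 12, 24, and 16 layers separated by transition layers, and a final global average pool
/// feeding a softmax classifier.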
public struct DenseNet121: Layer {
    public var conv = Conv(
        filterSize: 7,
        stride: 2,
        inputFilterCount: 3,
        outputFilterCount: 64
    )
    public var maxpool = MaxPool2D<Float>(
        poolSize: (3, 3),
        strides: (2, 2),
        padding: .same
    )
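    // With the default growth rate of 32, each dense block adds `repetitionCount * 32` filters
    // and each transition layer then halves the filter count:
    //    64 +  6 * 32 =  256  -> 128 after transitionLayer1
    //   128 + 12 * 32 =  512  -> 256 after transitionLayer2
    //   256 + 24 * 32 = 1024  -> 512 after transitionLayer3
    //   512 + 16 * 32 = 1024  -> globalAvgPool -> dense(1024, classCount)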
    public var denseBlock1 = DenseBlock(repetitionCount: 6, inputFilterCount: 64)
    public var transitionLayer1 = TransitionLayer(inputFilterCount: 256)
    public var denseBlock2 = DenseBlock(repetitionCount: 12, inputFilterCount: 128)
    public var transitionLayer2 = TransitionLayer(inputFilterCount: 512)
    public var denseBlock3 = DenseBlock(repetitionCount: 24, inputFilterCount: 256)
    public var transitionLayer3 = TransitionLayer(inputFilterCount: 1024)
    public var denseBlock4 = DenseBlock(repetitionCount: 16, inputFilterCount: 512)
    public var globalAvgPool = GlobalAvgPool2D<Float>()
    public var dense: Dense<Float>

    public init(classCount: Int) {
        dense = Dense(inputSize: 1024, outputSize: classCount, activation: softmax)
    }

    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let inputLayer = input.sequenced(through: conv, maxpool)
        let level1 = inputLayer.sequenced(through: denseBlock1, transitionLayer1)
        let level2 = level1.sequenced(through: denseBlock2, transitionLayer2)
        let level3 = level2.sequenced(through: denseBlock3, transitionLayer3)
        let output = level3.sequenced(through: denseBlock4, globalAvgPool, dense)
        return output
    }
}

extension DenseNet121 {
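    /// A pre-activation convolution unit: batch normalization and ReLU followed by a convolution.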
    public struct Conv: Layer {
        public var batchNorm: BatchNorm<Float>
        public var conv: Conv2D<Float>

        public init(
            filterSize: Int,
            stride: Int = 1,
            inputFilterCount: Int,
            outputFilterCount: Int
        ) {
            batchNorm = BatchNorm(featureCount: inputFilterCount)
            conv = Conv2D(
                filterShape: (filterSize, filterSize, inputFilterCount, outputFilterCount),
                strides: (stride, stride),
                padding: .same
            )
        }

        @differentiable
        public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            conv(relu(batchNorm(input)))
        }
    }

    /// A pair of a 1x1 `Conv` layer and a 3x3 `Conv` layer.
    public struct ConvPair: Layer {
        public var conv1x1: Conv
        public var conv3x3: Conv

        public init(inputFilterCount: Int, growthRate: Int) {
            conv1x1 = Conv(
                filterSize: 1,
                inputFilterCount: inputFilterCount,
                outputFilterCount: inputFilterCount * 2
            )
            conv3x3 = Conv(
                filterSize: 3,
                inputFilterCount: inputFilterCount * 2,
                outputFilterCount: growthRate
            )
        }

        @differentiable
        public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            let conv1Output = conv1x1(input)
            let conv3Output = conv3x3(conv1Output)
            return conv3Output.concatenated(with: input, alongAxis: -1)
        }
    }

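    /// A dense block: a chain of `ConvPair`s, each of which concatenates its output onto its
    /// input, so the filter count grows by `growthRate` with every repetition.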
    public struct DenseBlock: Layer {
        public var pairs: [ConvPair] = []

        public init(repetitionCount: Int, growthRate: Int = 32, inputFilterCount: Int) {
            for i in 0..<repetitionCount {
                let filterCount = inputFilterCount + i * growthRate
                pairs.append(ConvPair(inputFilterCount: filterCount, growthRate: growthRate))
            }
        }

        @differentiable
        public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            pairs.differentiableReduce(input) { last, layer in
                layer(last)
            }
        }
    }

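    /// A transition layer between dense blocks: a 1x1 `Conv` that halves the filter count,
    /// followed by a strided average pool that halves the spatial resolution.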
    public struct TransitionLayer: Layer {
        public var conv: Conv
        public var pool: AvgPool2D<Float>

        public init(inputFilterCount: Int) {
            conv = Conv(
                filterSize: 1,
                inputFilterCount: inputFilterCount,
                outputFilterCount: inputFilterCount / 2
            )
            pool = AvgPool2D(poolSize: (2, 2), strides: (2, 2), padding: .same)
        }

        @differentiable
        public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            input.sequenced(through: conv, pool)
        }
    }
}
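
// An illustrative usage sketch (not part of the original file); the function name is hypothetical.
// It runs a forward pass on a random batch of 224x224 RGB images; the output has shape
// [batchSize, classCount] and is already softmax-normalized by the final Dense layer.
func runDenseNet121Example() {
    let model = DenseNet121(classCount: 1000)
    let images = Tensor<Float>(randomNormal: [8, 224, 224, 3])
    let probabilities = model(images)
    print(probabilities.shape)  // [8, 1000]
}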