@@ -138,6 +138,11 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
 /// `Dense` implements the operation `activation(matmul(input, weight) + bias)`, where `weight` is
 /// a weight matrix, `bias` is a bias vector, and `activation` is an element-wise activation
 /// function.
+///
+/// This layer also supports 3-D weight tensors with 2-D bias matrices. In this case, the first
+/// dimension of both is treated as the batch size that is aligned with the first dimension of
+/// `input`, and the batched variant of the `matmul(_:_:)` operation is used, thus using a
+/// different weight and bias for each element in the input batch.
 @frozen
 public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The weight matrix.
@@ -146,6 +151,8 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     public var bias: Tensor<Scalar>
     /// The element-wise activation function.
     @noDerivative public let activation: Activation
+    /// Indicates whether this is a batched dense layer.
+    @noDerivative internal let batched: Bool

     /// The element-wise activation function type.
     public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
@@ -155,9 +162,12 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
         bias: Tensor<Scalar>,
         activation: @escaping Activation
     ) {
+        precondition(weight.rank <= 3, "The rank of the 'weight' tensor must be less than 4.")
+        precondition(bias.rank <= 2, "The rank of the 'bias' tensor must be less than 3.")
         self.weight = weight
         self.bias = bias
         self.activation = activation
+        self.batched = weight.rank == 3
     }

     /// Returns the output obtained from applying the layer to the given input.
@@ -166,6 +176,10 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Returns: The output.
     @differentiable
     public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+        if batched {
+            let hidden = matmul(input.expandingShape(at: 1), weight)
+            return activation(hidden.squeezingShape(at: 1) + bias)
+        }
         return activation(matmul(input, weight) + bias)
     }
 }
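For readers unfamiliar with the new code path, here is a minimal usage sketch of the batched `Dense` layer added above. The concrete sizes and the choice of `relu` are illustrative assumptions, not part of the commit:

```swift
import TensorFlow

// Hypothetical sizes: a batch of 4 examples, 3 input features, 2 output features.
// A rank-3 weight with a rank-2 bias selects the batched path (weight.rank == 3).
let weight = Tensor<Float>(randomNormal: [4, 3, 2])  // [batchSize, inputSize, outputSize]
let bias = Tensor<Float>(zeros: [4, 2])              // [batchSize, outputSize]
let layer = Dense<Float>(weight: weight, bias: bias, activation: relu)

let input = Tensor<Float>(randomNormal: [4, 3])      // [batchSize, inputSize]
let output = layer(input)                            // [batchSize, outputSize], i.e. [4, 2]
```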
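The forward pass relies on `matmul(_:_:)` dispatching to a batched matrix multiply when both operands have rank 3. A sketch of the shape bookkeeping behind the added `if batched` branch, under the same assumed sizes:

```swift
import TensorFlow

let input = Tensor<Float>(randomNormal: [4, 3])      // [b, in]
let weight = Tensor<Float>(randomNormal: [4, 3, 2])  // [b, in, out]
let bias = Tensor<Float>(zeros: [4, 2])              // [b, out]

// Each example is lifted to a [1, in] matrix so that the batched matmul
// pairs example i with weight slice i.
let expanded = input.expandingShape(at: 1)           // [b, 1, in]
let hidden = matmul(expanded, weight)                // [b, 1, out]
let output = hidden.squeezingShape(at: 1) + bias     // [b, out]
```

Expanding and then squeezing around the batched matmul avoids a per-example loop and keeps the whole computation differentiable.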