
Commit fdda63e

Enhanced the 'Dense' layer to support batch matrix multiplication. (#324)
1 parent 9a3f393 commit fdda63e

1 file changed: 14 additions and 0 deletions

Sources/TensorFlow/Layers/Core.swift

Lines changed: 14 additions & 0 deletions
@@ -138,6 +138,11 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
 /// `Dense` implements the operation `activation(matmul(input, weight) + bias)`, where `weight` is
 /// a weight matrix, `bias` is a bias vector, and `activation` is an element-wise activation
 /// function.
+///
+/// This layer also supports 3-D weight tensors with 2-D bias matrices. In this case, the first
+/// dimension of both is treated as the batch size that is aligned with the first dimension of
+/// `input`, and the batch variant of the `matmul(_:_:)` operation is used, thus using a different
+/// weight and bias for each element in the input batch.
 @frozen
 public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The weight matrix.
@@ -146,6 +151,8 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     public var bias: Tensor<Scalar>
     /// The element-wise activation function.
     @noDerivative public let activation: Activation
+    /// Indicates whether this is a batched dense layer.
+    @noDerivative internal let batched: Bool
 
     /// The element-wise activation function type.
     public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
@@ -155,9 +162,12 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
         bias: Tensor<Scalar>,
         activation: @escaping Activation
     ) {
+        precondition(weight.rank <= 3, "The rank of the 'weight' tensor must be less than 4.")
+        precondition(bias.rank <= 2, "The rank of the 'bias' tensor must be less than 3.")
         self.weight = weight
         self.bias = bias
         self.activation = activation
+        self.batched = weight.rank == 3
     }
 
     /// Returns the output obtained from applying the layer to the given input.
@@ -166,6 +176,10 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - Returns: The output.
     @differentiable
     public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
+        if batched {
+            let hidden = matmul(input.expandingShape(at: 1), weight)
+            return activation(hidden.squeezingShape(at: 1) + bias)
+        }
         return activation(matmul(input, weight) + bias)
     }
 }
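
For illustration only (not part of the commit), here is a minimal sketch of how the batched path might be exercised. The shapes, the values, and the choice of `relu` are assumptions made for this example:

import TensorFlow

let batchSize = 4
let inputSize = 3
let outputSize = 2

// A rank-3 weight and a rank-2 bias select the batched path
// (the initializer sets `batched = weight.rank == 3`).
let weight = Tensor<Float>(randomNormal: [batchSize, inputSize, outputSize])
let bias = Tensor<Float>(zeros: [batchSize, outputSize])
let dense = Dense<Float>(weight: weight, bias: bias, activation: relu)

// The input is the usual rank-2 [batchSize, inputSize] tensor; each of its
// rows is multiplied by its own [inputSize, outputSize] weight slice.
let input = Tensor<Float>(randomNormal: [batchSize, inputSize])
let output = dense(input)
print(output.shape)  // [4, 2]

Expanding the input to rank 3 lets the batched `matmul(_:_:)` treat each example as a 1 x `inputSize` matrix paired with its own weight slice ([4, 1, 3] x [4, 3, 2] yields [4, 1, 2]); squeezing the singleton dimension afterwards restores the usual [batchSize, outputSize] output before the per-example bias is added.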
