Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add average, multiply and stack modes for BiRNNs #1101

Merged
merged 1 commit into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 90 additions & 2 deletions Sources/TensorFlow/Layers/Recurrent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RecurrentLayerCell {
/// Concatenates two values.
@differentiable
public static func concatenate(_ lhs: Self, _ rhs: Self) -> Self {
// TODO: Remove workaround after https://github.com/tensorflow/swift-apis/issues/1087 is fixed.
// TODO(TF-1005): Remove workaround for differenting concatenated.
let concatCell = lhs.cell.concatenated(with: rhs.cell, alongAxis: -1)
let concatHidden = lhs.hidden.concatenated(with: rhs.hidden, alongAxis: -1)
let cell = concatCell.withDerivative { [shape = concatCell.shape] in
Expand All @@ -228,6 +228,33 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RecurrentLayerCell {
public static func sum(_ lhs: Self, _ rhs: Self) -> Self {
Self(cell: lhs.cell + rhs.cell, hidden: lhs.hidden + rhs.hidden)
}

/// Averages two values.
@differentiable
public static func average(_ lhs: Self, _ rhs: Self) -> Self {
Self(cell: (lhs.cell + rhs.cell) / 2, hidden: (lhs.hidden + rhs.hidden) / 2)
}

/// Multiplies two values.
@differentiable
public static func multiply(_ lhs: Self, _ rhs: Self) -> Self {
Self(cell: lhs.cell * rhs.cell, hidden: lhs.hidden * rhs.hidden)
}

/// Stack two values.
@differentiable
public static func stack(_ lhs: Self, _ rhs: Self) -> Self {
// TODO(TF-1005): Remove workaround for differenting stacking.
let stackCell = Tensor(stacking: [lhs.cell, rhs.cell])
let stackHidden = Tensor(stacking: [lhs.hidden, rhs.hidden])
let cell = stackCell.withDerivative { [shape = stackCell.shape] in
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
}
let hidden = stackHidden.withDerivative { [shape = stackHidden.shape] in
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
}
return Self(cell: cell, hidden: hidden)
}
}

/// Returns a zero-valued state with shape compatible with the provided input.
Expand Down Expand Up @@ -455,13 +482,25 @@ public protocol Mergeable: Differentiable, AdditiveArithmetic {
/// `Mergeable` (SR-13229).
@differentiable
static func sum(_ lhs: Self, _ rhs: Self) -> Self

/// Averages two values.
@differentiable
static func average(_ lhs: Self, _ rhs: Self) -> Self

/// Multiplies two values.
@differentiable
static func multiply(_ lhs: Self, _ rhs: Self) -> Self

/// Stack two values.
@differentiable
static func stack(_ lhs: Self, _ rhs: Self) -> Self
}

extension Tensor: Mergeable where Scalar: TensorFlowFloatingPoint {
/// Concatenates two tensors along last axis.
@differentiable
public static func concatenate(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
// TODO: Remove workaround after https://github.com/tensorflow/swift-apis/issues/1087 is fixed.
// TODO(TF-1005): Remove workaround for differenting concatenated.
let concat = lhs.concatenated(with: rhs, alongAxis: -1)
return concat.withDerivative { [shape = concat.shape] in
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
Expand All @@ -473,6 +512,28 @@ extension Tensor: Mergeable where Scalar: TensorFlowFloatingPoint {
public static func sum(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
lhs + rhs
}

/// Averages two values.
@differentiable
public static func average(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
(lhs + rhs) / 2
}

/// Multiplies two values.
@differentiable
public static func multiply(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
lhs * rhs
}

/// Stack two values.
@differentiable
public static func stack(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
// TODO(TF-1005): Remove workaround for differenting stacking.
let stack = Tensor(stacking: [lhs, rhs])
return stack.withDerivative { [shape = stack.shape] in
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
}
}
}

/// Concatenates two values.
Expand All @@ -493,6 +554,33 @@ public func sum<T: Mergeable>(
T.sum(first, second)
}

/// Averages two values.
@differentiable
public func average<T: Mergeable>(
_ first: T,
_ second: T
) -> T {
T.average(first, second)
}

/// Multiplies two values.
@differentiable
public func multiply<T: Mergeable>(
_ first: T,
_ second: T
) -> T {
T.multiply(first, second)
}

/// Stack two values.
@differentiable
public func stack<T: Mergeable>(
_ first: T,
_ second: T
) -> T {
T.stack(first, second)
}

public struct BidirectionalRecurrentLayer<Cell: RecurrentLayerCell>: Layer
where Cell.TimeStepOutput: Mergeable {
public typealias Input = [Cell.TimeStepInput]
Expand Down
Loading