This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Added support for negative 'axis' argument values to 'Tensor.unstacked(alongAxis:)'. #351
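The change follows the TensorFlow convention that a negative `axis` counts back from the last dimension. As a minimal standalone sketch of the normalization the patch performs (the helper name here is hypothetical; the patch inlines the same expression directly in `unstacked(alongAxis:)`):

```swift
/// Wraps a possibly negative axis into the range [0, rank).
/// For example, axis -1 addresses the last dimension.
func normalized(axis: Int, rank: Int) -> Int {
    precondition(axis >= -rank && axis < rank, "axis out of range [-rank, rank)")
    return axis < 0 ? axis + rank : axis
}

// normalized(axis: -1, rank: 4) == 3
// normalized(axis: 2, rank: 4) == 2
```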

Merged · 4 commits · Aug 9, 2019
55 changes: 26 additions & 29 deletions Sources/TensorFlow/Operators/Basic.swift
@@ -34,9 +34,10 @@ public extension TensorFlowScalar {
 }
 
 public extension Tensor {
-    /// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors. Unpacks
-    /// `N` tensors from this tensor by chipping it along the `axis` dimension, where `N` is
-    /// inferred from this tensor's shape. For example, given a tensor with shape `[A, B, C, D]`:
+    /// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors.
+    /// Unpacks `N` tensors from this tensor by chipping it along the `axis` dimension, where `N`
+    /// is inferred from this tensor's shape. For example, given a tensor with shape
+    /// `[A, B, C, D]`:
     ///
     /// - If `axis == 0` then the `i`-th tensor in the returned array is the slice
     ///   `self[i, :, :, :]` and each tensor in that array will have shape `[B, C, D]`.
@@ -51,14 +52,15 @@ public extension Tensor {
     /// - Parameters:
     ///   - axis: Dimension along which to unstack. Negative values wrap around.
     ///
-    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
-    ///   provided tensors.
+    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
+    ///   the provided tensors.
     ///
     /// - Returns: Array containing the unstacked tensors.
     @inlinable
     @differentiable(vjp: _vjpUnstacked(alongAxis:) where Scalar: TensorFlowFloatingPoint)
     func unstacked(alongAxis axis: Int = 0) -> [Tensor] {
-        return Raw.unpack(value: self, num: Int64(shape[axis]), axis: Int64(axis))
+        let posAxis = axis < 0 ? axis + rank : axis
+        return Raw.unpack(value: self, num: Int64(shape[posAxis]), axis: Int64(posAxis))
     }
 
     /// Splits a tensor into multiple tensors. The tensor is split along dimension `axis` into
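As a usage sketch of the fixed `unstacked(alongAxis:)`, assuming a Swift for TensorFlow toolchain with this patch applied:

```swift
import TensorFlow

let t = Tensor<Float>(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6])

// axis -1 wraps to 1 (posAxis = -1 + rank), chipping out the 3 columns.
let columns = t.unstacked(alongAxis: -1)
// columns.count == 3; columns[0] == Tensor<Float>([1, 4])

// Identical to the non-negative spelling:
let sameColumns = t.unstacked(alongAxis: 1)
```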
@@ -79,15 +81,14 @@ public extension Tensor {
     ///   - axis: The dimension along which to split this tensor. Negative values wrap around.
     ///
     /// - Precondition: `count` must divide the size of dimension `axis` evenly.
-    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
-    ///   provided tensors.
+    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
+    ///   the provided tensors.
     ///
     /// - Returns: An array containing the tensors parts.
     @inlinable
     @differentiable(vjp: _vjpSplit(count:alongAxis:) where Scalar: TensorFlowFloatingPoint)
     func split(count: Int, alongAxis axis: Int = 0) -> [Tensor] {
-        return Raw.split(
-            splitDim: Tensor<Int32>(Int32(axis)), value: self, numSplit: Int64(count))
+        Raw.split(splitDim: Tensor<Int32>(Int32(axis)), value: self, numSplit: Int64(count))
     }
 
     /// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces.
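A usage sketch of `split(count:alongAxis:)` with hypothetical values; `count` must evenly divide the size of the chosen dimension:

```swift
import TensorFlow

let x = Tensor<Float>(rangeFrom: 0, to: 12, stride: 1).reshaped(to: [4, 3])

// 4 rows split evenly into 2 pieces of shape [2, 3].
let halves = x.split(count: 2, alongAxis: 0)
// halves[0] contains rows 0-1, halves[1] rows 2-3.

// count: 3 along axis 0 would violate the precondition, since 4 % 3 != 0.
```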
@@ -109,16 +110,16 @@ public extension Tensor {
     ///   - axis: Dimension along which to split this tensor. Negative values wrap around.
     ///
     /// - Precondition: The values in `sizes` must add up to the size of dimension `axis`.
-    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
-    ///   provided tensors.
+    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
+    ///   the provided tensors.
     ///
     /// - Returns: Array containing the tensors parts.
     @inlinable
     @differentiable(
         wrt: self,
         vjp: _vjpSplit(sizes:alongAxis:) where Scalar: TensorFlowFloatingPoint)
     func split(sizes: Tensor<Int32>, alongAxis axis: Int = 0) -> [Tensor] {
-        return Raw.splitV(
+        Raw.splitV(
             value: self,
             sizeSplits: sizes,
             splitDim: Tensor<Int32>(Int32(axis)),
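`split(sizes:alongAxis:)` generalizes this to uneven pieces; a sketch with hypothetical values, where the sizes must sum to the dimension's extent:

```swift
import TensorFlow

let x = Tensor<Float>(zeros: [5, 4])

// Uneven split along axis 0: pieces of 2, 1, and 2 rows (2 + 1 + 2 == 5).
let parts = x.split(sizes: Tensor<Int32>([2, 1, 2]), alongAxis: 0)
// parts.map { $0.shape } == [[2, 4], [1, 4], [2, 4]]
```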
@@ -136,15 +137,15 @@ public extension Tensor {
     @inlinable
     @differentiable(wrt: self, vjp: _vjpTiled(multiples:) where Scalar: TensorFlowFloatingPoint)
     func tiled(multiples: Tensor<Int32>) -> Tensor {
-        return Raw.tile(self, multiples: multiples)
+        Raw.tile(self, multiples: multiples)
     }
 
     /// Reshape to the shape of the specified `Tensor`.
     /// - Precondition: The number of scalars matches the new shape.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func reshaped<T>(like other: Tensor<T>) -> Tensor {
-        return reshaped(toShape: other.shapeTensor)
+        reshaped(toShape: other.shapeTensor)
     }
 
     /// Reshape to the specified shape.
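A sketch of the two functions touched just above, `tiled(multiples:)` and `reshaped(like:)`, with hypothetical values:

```swift
import TensorFlow

let t = Tensor<Float>([[1, 2], [3, 4]])     // shape [2, 2]

// Tile 2x along axis 0 and 3x along axis 1: shape becomes [4, 6].
let tiledT = t.tiled(multiples: Tensor<Int32>([2, 3]))

// reshaped(like:) adopts another tensor's (runtime) shape.
let flat = Tensor<Float>([10, 20, 30, 40])  // shape [4]
let square = flat.reshaped(like: t)         // shape [2, 2]
```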
@@ -153,32 +154,30 @@ public extension Tensor {
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func reshaped(to newShape: TensorShape) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
-        return reshaped(toShape: Tensor<Int32>({newShape.dimensions.map(Int32.init)}()))
+        reshaped(toShape: Tensor<Int32>({newShape.dimensions.map(Int32.init)}()))
     }
 
     /// Reshape to the specified `Tensor` representing a shape.
     /// - Precondition: The number of scalars matches the new shape.
     @inlinable
-    @differentiable(
-        wrt: self,
-        vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint)
     func reshaped(toShape newShape: Tensor<Int32>) -> Tensor {
-        return Raw.reshape(self, shape: newShape)
+        Raw.reshape(self, shape: newShape)
     }
 
     /// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func flattened() -> Tensor {
-        return reshaped(to: [-1])
+        reshaped(to: [-1])
     }
 
     /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the specified shape
     /// indices.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func expandingShape(at axes: Int...) -> Tensor {
-        return expandingShape(at: axes)
+        expandingShape(at: axes)
     }
 
     /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
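The shape utilities in this hunk compose naturally; a brief sketch under the same toolchain assumptions:

```swift
import TensorFlow

let v = Tensor<Float>([1, 2, 3])        // shape [3]

let row = v.expandingShape(at: 0)       // shape [1, 3]
let col = v.expandingShape(at: 1)       // shape [3, 1]
let grid = col.reshaped(to: [1, 3])     // shape [1, 3]
let back = grid.flattened()             // shape [3], row-major order
```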
@@ -195,23 +194,23 @@ public extension Tensor {
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func rankLifted() -> Tensor {
-        return expandingShape(at: 0)
+        expandingShape(at: 0)
     }
 
     /// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are
     /// specified, then all dimensions of size 1 will be removed.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func squeezingShape(at axes: Int...) -> Tensor {
-        return squeezingShape(at: axes)
+        squeezingShape(at: axes)
     }
 
     /// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are
     /// specified, then all dimensions of size 1 will be removed.
     @inlinable
     @differentiable(wrt: self, vjp: _vjpSqueezingShape(at:) where Scalar: TensorFlowFloatingPoint)
     func squeezingShape(at axes: [Int]) -> Tensor {
-        return Raw.squeeze(self, squeezeDims: axes.map(Int32.init))
+        Raw.squeeze(self, squeezeDims: axes.map(Int32.init))
     }
 }
 
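And a sketch of the squeezing/lifting pair from this hunk, again with hypothetical values:

```swift
import TensorFlow

let t = Tensor<Float>(zeros: [1, 3, 1, 2])

let squeezedAll = t.squeezingShape()       // all size-1 dims removed: [3, 2]
let squeezedOne = t.squeezingShape(at: 0)  // only the leading dim: [3, 1, 2]

// rankLifted() == expandingShape(at: 0), undoing the squeeze above.
let lifted = squeezedOne.rankLifted()      // [1, 3, 1, 2]
```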
@@ -225,10 +224,8 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
     }
 
     @inlinable
-    func _vjpTiled(
-        multiples: Tensor<Int32>
-    ) -> (Tensor, (Tensor) -> Tensor) {
-        return (tiled(multiples: multiples), { [shape = shapeTensor] v in
+    func _vjpTiled(multiples: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        (tiled(multiples: multiples), { [shape = shapeTensor] v in
             let splitShape = Tensor<Int32>(stacking: [multiples, shape]).transposed().flattened()
             let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(splitShape.scalarCount), stride: 2)
             return v.reshaped(toShape: splitShape).sum(squeezingAxes: axes)
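For intuition about the `_vjpTiled` pullback retained above: tiling copies each element once per tile, so the pullback must sum the incoming cotangent over every copy. Interleaving `multiples` with the original `shape` (via `stacking`, `transposed`, and `flattened`) turns each dimension of size `s` tiled `m` times into an `(m, s)` pair, and summing over the even axes collapses the copies. A quick sanity check, assuming the same toolchain:

```swift
import TensorFlow

// d/dx of sum(tile(x, [2, 3])) should be all 6s: each element of x
// appears in 2 * 3 = 6 positions of the tiled tensor.
let x = Tensor<Float>(ones: [2, 3])
let grad = gradient(at: x) { x in
    x.tiled(multiples: Tensor<Int32>([2, 3])).sum()
}
// grad == Tensor<Float>(repeating: 6, shape: [2, 3])
```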
15 changes: 15 additions & 0 deletions Tests/TensorFlowTests/TensorAutoDiffTests.swift
@@ -153,6 +153,19 @@ final class TensorAutoDiffTests: XCTestCase {
         XCTAssertEqual(varianceGradAlongAxes(input), expected)
     }
 
+    // TODO: Uncomment once TF-653 is resolved.
+    // func testTensorInitStacking() {
+    //     let a1 = Tensor<Float>([1, 2, 3, 4, 5])
+    //     let b1 = Tensor<Float>([6, 7, 8, 9, 10])
+    //     let a2 = Tensor<Float>([1, 1, 1, 1, 1])
+    //     let b2 = Tensor<Float>([1, 1, 1, 1, 1])
+    //     let grads = gradient(at: a2, b2) { a, b in
+    //         Tensor<Float>(stacking: [a1 * a, b1 * b], alongAxis: -1).sum()
+    //     }
+    //     XCTAssertEqual(a1, grads.0)
+    //     XCTAssertEqual(b1, grads.1)
+    // }
+
     func testExpandingShape() {
         func f1(a: Tensor<Float>) -> Tensor<Float> { a.expandingShape(at: 0).squared() }
         func f2(a: Tensor<Float>) -> Tensor<Float> { a.squared().expandingShape(at: 0) }
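The expected values in the disabled test follow from linearity: the stacked sum equals `(a1 * a).sum() + (b1 * b).sum()`, whose gradients with respect to `a` and `b` are `a1` and `b1`. A sketch that checks the same identity without going through `Tensor(stacking:)`, and therefore without depending on TF-653:

```swift
import TensorFlow

let a1 = Tensor<Float>([1, 2, 3, 4, 5])
let b1 = Tensor<Float>([6, 7, 8, 9, 10])

// Gradients of a linear function recover its coefficients.
let grads = gradient(at: Tensor<Float>(ones: [5]), Tensor<Float>(ones: [5])) { a, b in
    (a1 * a).sum() + (b1 * b).sum()
}
// grads.0 == a1, grads.1 == b1
```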
@@ -435,6 +448,8 @@ final class TensorAutoDiffTests: XCTestCase {
         ("testSum", testSum),
         ("testMean", testMean),
         ("testVariance", testVariance),
+        // TODO: Uncomment once TF-653 is resolved.
+        // ("testTensorInitStacking", testTensorInitStacking),
         ("testExpandingShape", testExpandingShape),
         ("testSqueezingShape", testSqueezingShape),
         ("testReshapedBackprop", testReshapedBackprop),