This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit cc36830

Added support for negative 'axis' argument values to 'Tensor.unstacked(alongAxis:)'. (#351)

1 parent 268055e

2 files changed: +41 −29 lines
Sources/TensorFlow/Operators/Basic.swift

Lines changed: 26 additions & 29 deletions
@@ -34,9 +34,10 @@ public extension TensorFlowScalar {
 }
 
 public extension Tensor {
-    /// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors. Unpacks
-    /// `N` tensors from this tensor by chipping it along the `axis` dimension, where `N` is
-    /// inferred from this tensor's shape. For example, given a tensor with shape `[A, B, C, D]`:
+    /// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors.
+    /// Unpacks `N` tensors from this tensor by chipping it along the `axis` dimension, where `N`
+    /// is inferred from this tensor's shape. For example, given a tensor with shape
+    /// `[A, B, C, D]`:
     ///
     /// - If `axis == 0` then the `i`-th tensor in the returned array is the slice
     ///   `self[i, :, :, :]` and each tensor in that array will have shape `[B, C, D]`.
@@ -51,14 +52,15 @@ public extension Tensor {
     /// - Parameters:
     ///   - axis: Dimension along which to unstack. Negative values wrap around.
     ///
-    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
-    ///   provided tensors.
+    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
+    ///   the provided tensors.
     ///
     /// - Returns: Array containing the unstacked tensors.
     @inlinable
     @differentiable(vjp: _vjpUnstacked(alongAxis:) where Scalar: TensorFlowFloatingPoint)
     func unstacked(alongAxis axis: Int = 0) -> [Tensor] {
-        return Raw.unpack(value: self, num: Int64(shape[axis]), axis: Int64(axis))
+        let posAxis = axis < 0 ? axis + rank : axis
+        return Raw.unpack(value: self, num: Int64(shape[posAxis]), axis: Int64(posAxis))
     }
 
     /// Splits a tensor into multiple tensors. The tensor is split along dimension `axis` into
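On the hunk above: the fix first normalizes a negative `axis` by adding `rank`, since the subsequent `shape[posAxis]` lookup requires a non-negative index. A minimal usage sketch (illustrative; the tensor values and the `init(shape:scalars:)` call are my own, not part of this commit):

import TensorFlow

let x = Tensor<Float>(shape: [2, 2], scalars: [1, 2, 3, 4])  // [[1, 2], [3, 4]]
// -1 wraps around to axis 1 (the last axis), so this unstacks the columns.
let columns = x.unstacked(alongAxis: -1)
// columns[0] == Tensor<Float>([1, 3]); columns[1] == Tensor<Float>([2, 4])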
@@ -79,15 +81,14 @@ public extension Tensor {
     ///   - axis: The dimension along which to split this tensor. Negative values wrap around.
     ///
     /// - Precondition: `count` must divide the size of dimension `axis` evenly.
-    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
-    ///   provided tensors.
+    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
+    ///   the provided tensors.
     ///
     /// - Returns: An array containing the tensors parts.
     @inlinable
     @differentiable(vjp: _vjpSplit(count:alongAxis:) where Scalar: TensorFlowFloatingPoint)
     func split(count: Int, alongAxis axis: Int = 0) -> [Tensor] {
-        return Raw.split(
-            splitDim: Tensor<Int32>(Int32(axis)), value: self, numSplit: Int64(count))
+        Raw.split(splitDim: Tensor<Int32>(Int32(axis)), value: self, numSplit: Int64(count))
     }
 
     /// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces.
@@ -109,16 +110,16 @@ public extension Tensor {
     ///   - axis: Dimension along which to split this tensor. Negative values wrap around.
     ///
     /// - Precondition: The values in `sizes` must add up to the size of dimension `axis`.
-    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
-    ///   provided tensors.
+    /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
+    ///   the provided tensors.
     ///
     /// - Returns: Array containing the tensors parts.
     @inlinable
     @differentiable(
         wrt: self,
         vjp: _vjpSplit(sizes:alongAxis:) where Scalar: TensorFlowFloatingPoint)
     func split(sizes: Tensor<Int32>, alongAxis axis: Int = 0) -> [Tensor] {
-        return Raw.splitV(
+        Raw.splitV(
             value: self,
             sizeSplits: sizes,
             splitDim: Tensor<Int32>(Int32(axis)),
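Both `split` overloads above document that negative `axis` values wrap around; the underlying TensorFlow split ops accept a negative split dimension directly. A usage sketch (the values and the `init(shape:scalars:)` call are illustrative assumptions):

import TensorFlow

let x = Tensor<Float>(shape: [1, 6], scalars: [0, 1, 2, 3, 4, 5])
// Even split: three parts of shape [1, 2] along the last axis.
let thirds = x.split(count: 3, alongAxis: -1)
// Sized split: parts of shape [1, 1], [1, 2], and [1, 3].
let sized = x.split(sizes: Tensor<Int32>([1, 2, 3]), alongAxis: -1)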
@@ -136,15 +137,15 @@ public extension Tensor {
     @inlinable
     @differentiable(wrt: self, vjp: _vjpTiled(multiples:) where Scalar: TensorFlowFloatingPoint)
     func tiled(multiples: Tensor<Int32>) -> Tensor {
-        return Raw.tile(self, multiples: multiples)
+        Raw.tile(self, multiples: multiples)
     }
 
     /// Reshape to the shape of the specified `Tensor`.
     /// - Precondition: The number of scalars matches the new shape.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func reshaped<T>(like other: Tensor<T>) -> Tensor {
-        return reshaped(toShape: other.shapeTensor)
+        reshaped(toShape: other.shapeTensor)
     }
 
     /// Reshape to the specified shape.
@@ -153,32 +154,30 @@ public extension Tensor {
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func reshaped(to newShape: TensorShape) -> Tensor {
         // TODO(TF-433): Remove workaround for differentiating `map`.
-        return reshaped(toShape: Tensor<Int32>({newShape.dimensions.map(Int32.init)}()))
+        reshaped(toShape: Tensor<Int32>({newShape.dimensions.map(Int32.init)}()))
     }
 
     /// Reshape to the specified `Tensor` representing a shape.
     /// - Precondition: The number of scalars matches the new shape.
     @inlinable
-    @differentiable(
-        wrt: self,
-        vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self, vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint)
     func reshaped(toShape newShape: Tensor<Int32>) -> Tensor {
-        return Raw.reshape(self, shape: newShape)
+        Raw.reshape(self, shape: newShape)
     }
 
     /// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func flattened() -> Tensor {
-        return reshaped(to: [-1])
+        reshaped(to: [-1])
     }
 
     /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the specified shape
     /// indices.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func expandingShape(at axes: Int...) -> Tensor {
-        return expandingShape(at: axes)
+        expandingShape(at: axes)
     }
 
     /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
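Note that the reshape-family changes in this hunk and the surrounding ones only drop the explicit `return` (Swift 5.1 implicit returns for single-expression functions) or compress attribute formatting; behavior is unchanged. For instance (an illustrative sketch; the values are my own):

import TensorFlow

let m = Tensor<Float>(shape: [2, 3], scalars: [0, 1, 2, 3, 4, 5])
let flat = m.flattened()               // shape [6]
let restored = flat.reshaped(like: m)  // shape [2, 3] again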
@@ -195,23 +194,23 @@ public extension Tensor {
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func rankLifted() -> Tensor {
-        return expandingShape(at: 0)
+        expandingShape(at: 0)
     }
 
     /// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are
     /// specified, then all dimensions of size 1 will be removed.
     @inlinable
     @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func squeezingShape(at axes: Int...) -> Tensor {
-        return squeezingShape(at: axes)
+        squeezingShape(at: axes)
     }
 
     /// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are
     /// specified, then all dimensions of size 1 will be removed.
     @inlinable
     @differentiable(wrt: self, vjp: _vjpSqueezingShape(at:) where Scalar: TensorFlowFloatingPoint)
     func squeezingShape(at axes: [Int]) -> Tensor {
-        return Raw.squeeze(self, squeezeDims: axes.map(Int32.init))
+        Raw.squeeze(self, squeezeDims: axes.map(Int32.init))
     }
 }
 
@@ -225,10 +224,8 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
     }
 
     @inlinable
-    func _vjpTiled(
-        multiples: Tensor<Int32>
-    ) -> (Tensor, (Tensor) -> Tensor) {
-        return (tiled(multiples: multiples), { [shape = shapeTensor] v in
+    func _vjpTiled(multiples: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
+        (tiled(multiples: multiples), { [shape = shapeTensor] v in
             let splitShape = Tensor<Int32>(stacking: [multiples, shape]).transposed().flattened()
             let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(splitShape.scalarCount), stride: 2)
             return v.reshaped(toShape: splitShape).sum(squeezingAxes: axes)
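On the `_vjpTiled` pullback above: tiling repeats the tensor, so the pullback must sum the incoming cotangent over every copy. Reshaping `v` to the interleaved shape `[multiples[0], shape[0], multiples[1], shape[1], ...]` and summing over the even (tile-count) axes accomplishes exactly that. A numeric check (illustrative, not from this commit):

import TensorFlow

let x = Tensor<Float>([1, 2])
// Each scalar of `x` appears in 3 tiled copies, so the gradient of the sum is 3 everywhere.
let g = gradient(at: x) { $0.tiled(multiples: Tensor<Int32>([3])).sum() }
// g == Tensor<Float>([3, 3])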

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 15 additions & 0 deletions
@@ -153,6 +153,19 @@ final class TensorAutoDiffTests: XCTestCase {
         XCTAssertEqual(varianceGradAlongAxes(input), expected)
     }
 
+    // TODO: Uncomment once TF-653 is resolved.
+    // func testTensorInitStacking() {
+    //     let a1 = Tensor<Float>([1, 2, 3, 4, 5])
+    //     let b1 = Tensor<Float>([6, 7, 8, 9, 10])
+    //     let a2 = Tensor<Float>([1, 1, 1, 1, 1])
+    //     let b2 = Tensor<Float>([1, 1, 1, 1, 1])
+    //     let grads = gradient(at: a2, b2) { a, b in
+    //         Tensor<Float>(stacking: [a1 * a, b1 * b], alongAxis: -1).sum()
+    //     }
+    //     XCTAssertEqual(a1, grads.0)
+    //     XCTAssertEqual(b1, grads.1)
+    // }
+
     func testExpandingShape() {
         func f1(a: Tensor<Float>) -> Tensor<Float> { a.expandingShape(at: 0).squared() }
         func f2(a: Tensor<Float>) -> Tensor<Float> { a.squared().expandingShape(at: 0) }
@@ -435,6 +448,8 @@ final class TensorAutoDiffTests: XCTestCase {
         ("testSum", testSum),
         ("testMean", testMean),
         ("testVariance", testVariance),
+        // TODO: Uncomment once TF-653 is resolved.
+        // ("testTensorInitStacking", testTensorInitStacking),
         ("testExpandingShape", testExpandingShape),
         ("testSqueezingShape", testSqueezingShape),
         ("testReshapedBackprop", testReshapedBackprop),
