Skip to content

Commit dc6415e

Browse files
sguggerrxwei
authored andcommitted
Make expandingShape take [Int] or Int... (#24191)
1 parent 1ab9f04 commit dc6415e

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -553,12 +553,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
553553
}
554554

555555
@inlinable
556-
func _vjpExpandingShape(
557-
at shapeIndex: Int
558-
) -> (Tensor, (Tensor) -> Tensor) {
559-
let value = expandingShape(at: shapeIndex)
556+
func _vjpExpandingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
557+
let value = self.expandingShape(at: axes)
560558
return (value, { v in
561-
v.squeezingShape(at: shapeIndex)
559+
v.squeezingShape(at: axes)
562560
})
563561
}
564562
}

stdlib/public/TensorFlow/Tensor.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -706,14 +706,24 @@ public extension Tensor {
706706
}
707707

708708
/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
709-
/// specified shape index.
709+
/// specified shape indices.
710+
@inlinable @inline(__always)
711+
@differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
712+
func expandingShape(at axes: Int...) -> Tensor {
713+
return expandingShape(at: axes)
714+
}
715+
716+
/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
717+
/// specified shape indices.
710718
@inlinable @inline(__always)
711719
@differentiable(
712720
wrt: self, vjp: _vjpExpandingShape(at:)
713721
where Scalar : TensorFlowFloatingPoint
714722
)
715-
func expandingShape(at shapeIndex: Int) -> Tensor {
716-
return Raw.expandDims(self, dim: Tensor<Int32>(Int32(shapeIndex)))
723+
func expandingShape(at axes: [Int]) -> Tensor {
724+
var res = self
725+
for i in axes { res = Raw.expandDims(res, dim: Tensor<Int32>(Int32(i))) }
726+
return res
717727
}
718728

719729
/// Remove the specified dimensions of size 1 from the shape of a tensor. If

test/TensorFlowRuntime/tensor.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,20 @@ TensorTests.testAllBackends("MLPClassifierStruct") {
674674
expectPointwiseNearlyEqual([0.816997], prediction.scalars)
675675
}
676676

677+
TensorTests.testAllBackends("ExpandingShape") {
678+
// 2 x 3 -> 1 x 2 x 1 x 3 x 1
679+
let matrix = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])
680+
let reshaped = matrix.expandingShape(at: 0,2,4)
681+
682+
expectEqual([1, 2, 1, 3, 1], reshaped.shape)
683+
expectEqual(Array(0..<6), reshaped.scalars)
684+
685+
// 1 x 2 x 1 x 3 x 1 -> 2 x 3
686+
let rereshaped = reshaped.squeezingShape(at: 0,2,4)
687+
expectEqual([2, 3], rereshaped.shape)
688+
expectEqual(Array(0..<6), rereshaped.scalars)
689+
}
690+
677691
TensorTests.testAllBackends("Reshape") {
678692
// 2 x 3 -> 1 x 3 x 1 x 2 x 1
679693
let matrix = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])

0 commit comments

Comments
 (0)