Skip to content

Commit 059b5ba

Browse files
sguggerrxwei
authored andcommitted
---
yaml --- r: 312314 b: refs/heads/tensorflow-merge c: dc6415e h: refs/heads/master
1 parent 5be9f7e commit 059b5ba

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ refs/heads/chase-my-tail: 8bb91443a9e81bbfac92a2621a0af887a1da8dbf
13791379
refs/heads/consider-outer-alternatives: 708bac749ec60a22a79e2eefbe734f9488a7370d
13801380
refs/heads/revert-25740-oops-i-linked-it-again: fdd41aeb682fc488572bdc1cf71b2ff6997ba576
13811381
refs/heads/swift-5.1-branch-06-12-2019: e63b7b2d3b93c48232d386099d0ec525d21d8f8d
1382-
refs/heads/tensorflow-merge: 1ab9f04028722d7803b777b859bb4db83ead6a42
1382+
refs/heads/tensorflow-merge: dc6415ef0ea6e918c26bc04f4d89cea7a9bc1546
13831383
refs/heads/update-checkout-sha-info: 5832743c5c2a842976c42a508a4c6dcceefb0aef
13841384
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-12-a: 228f0448d9bb909aacbba4afcb7c600a405d15da
13851385
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-14-a: 922861a77b5fc2bf46bc917da70ceb15eef76836

branches/tensorflow-merge/stdlib/public/TensorFlow/Gradients.swift

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -553,12 +553,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
553553
}
554554

555555
@inlinable
556-
func _vjpExpandingShape(
557-
at shapeIndex: Int
558-
) -> (Tensor, (Tensor) -> Tensor) {
559-
let value = expandingShape(at: shapeIndex)
556+
func _vjpExpandingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
557+
let value = self.expandingShape(at: axes)
560558
return (value, { v in
561-
v.squeezingShape(at: shapeIndex)
559+
v.squeezingShape(at: axes)
562560
})
563561
}
564562
}

branches/tensorflow-merge/stdlib/public/TensorFlow/Tensor.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -706,14 +706,24 @@ public extension Tensor {
706706
}
707707

708708
/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
709-
/// specified shape index.
709+
/// specified shape indices.
710+
@inlinable @inline(__always)
711+
@differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
712+
func expandingShape(at axes: Int...) -> Tensor {
713+
return expandingShape(at: axes)
714+
}
715+
716+
/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
717+
/// specified shape indices.
710718
@inlinable @inline(__always)
711719
@differentiable(
712720
wrt: self, vjp: _vjpExpandingShape(at:)
713721
where Scalar : TensorFlowFloatingPoint
714722
)
715-
func expandingShape(at shapeIndex: Int) -> Tensor {
716-
return Raw.expandDims(self, dim: Tensor<Int32>(Int32(shapeIndex)))
723+
func expandingShape(at axes: [Int]) -> Tensor {
724+
var res = self
725+
for i in axes { res = Raw.expandDims(res, dim: Tensor<Int32>(Int32(i))) }
726+
return res
717727
}
718728

719729
/// Remove the specified dimensions of size 1 from the shape of a tensor. If

branches/tensorflow-merge/test/TensorFlowRuntime/tensor.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,20 @@ TensorTests.testAllBackends("MLPClassifierStruct") {
674674
expectPointwiseNearlyEqual([0.816997], prediction.scalars)
675675
}
676676

677+
TensorTests.testAllBackends("ExpandingShape") {
678+
// 2 x 3 -> 1 x 2 x 1 x 3 x 1
679+
let matrix = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])
680+
let reshaped = matrix.expandingShape(at: 0,2,4)
681+
682+
expectEqual([1, 2, 1, 3, 1], reshaped.shape)
683+
expectEqual(Array(0..<6), reshaped.scalars)
684+
685+
// 1 x 2 x 1 x 3 x 1 -> 2 x 3
686+
let rereshaped = reshaped.squeezingShape(at: 0,2,4)
687+
expectEqual([2, 3], rereshaped.shape)
688+
expectEqual(Array(0..<6), rereshaped.scalars)
689+
}
690+
677691
TensorTests.testAllBackends("Reshape") {
678692
// 2 x 3 -> 1 x 3 x 1 x 2 x 1
679693
let matrix = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])

0 commit comments

Comments
 (0)