Skip to content

Commit 85f1da5

Browse files
authored
---
yaml --- r: 312286 b: refs/heads/tensorflow-merge c: b30deb2 h: refs/heads/master
1 parent cc5b595 commit 85f1da5

File tree

4 files changed

+242
-163
lines changed

4 files changed

+242
-163
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ refs/heads/chase-my-tail: 8bb91443a9e81bbfac92a2621a0af887a1da8dbf
13791379
refs/heads/consider-outer-alternatives: 708bac749ec60a22a79e2eefbe734f9488a7370d
13801380
refs/heads/revert-25740-oops-i-linked-it-again: fdd41aeb682fc488572bdc1cf71b2ff6997ba576
13811381
refs/heads/swift-5.1-branch-06-12-2019: e63b7b2d3b93c48232d386099d0ec525d21d8f8d
1382-
refs/heads/tensorflow-merge: 3e8a1deb09ebecd6c93d44b02874a8a3c89fe5dc
1382+
refs/heads/tensorflow-merge: b30deb2d3d1720ac4e0e2041fa9f4c63d51fcb3b
13831383
refs/heads/update-checkout-sha-info: 5832743c5c2a842976c42a508a4c6dcceefb0aef
13841384
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-12-a: 228f0448d9bb909aacbba4afcb7c600a405d15da
13851385
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-14-a: 922861a77b5fc2bf46bc917da70ceb15eef76836

branches/tensorflow-merge/stdlib/public/TensorFlow/Gradients.swift

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -569,25 +569,15 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
569569

570570
extension Tensor where Scalar : TensorFlowFloatingPoint {
571571
@inlinable
572-
func _vjpSum() -> (Tensor, (Tensor) -> Tensor) {
573-
return (sum(), { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
574-
}
575-
576-
@inlinable
577-
func _vjpMean() -> (Tensor, (Tensor) -> Tensor) {
578-
return (mean(), { [shape = shapeTensor, count = scalarCountTensor] in
579-
($0 / Tensor(count)).broadcast(toShape: shape)
580-
})
581-
}
582-
583-
@inlinable
584-
func _vjpSum(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
572+
func _vjpSum(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
585573
let value = sum(alongAxes: axes)
586574
return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
587575
}
588576

589577
@inlinable
590-
func _vjpSum(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
578+
func _vjpSum(
579+
squeezingAxes axes: Tensor<Int32>
580+
) -> (Tensor, (Tensor) -> Tensor) {
591581
let value = sum(squeezingAxes: axes)
592582
return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
593583
}
@@ -602,20 +592,13 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
602592
}
603593

604594
@inlinable
605-
func _vjpMean(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
595+
func _vjpMean(
596+
squeezingAxes axes: Tensor<Int32>
597+
) -> (Tensor, (Tensor) -> Tensor) {
606598
let value = mean(squeezingAxes: axes)
607-
return (value, { [shape = shapeTensor,
608-
count = axes.map { shape[$0] }.reduce(1, *)] in
609-
$0.broadcast(toShape: shape) / Tensor(Scalar(count))
610-
})
611-
}
612-
613-
@inlinable
614-
func _vjpMean(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
615-
let value = mean(alongAxes: axes)
616-
return (value, { [shape = shapeTensor,
617-
count = axes.map { shape[$0] }.reduce(1, *)] in
618-
$0.broadcast(toShape: shape) / Tensor(Scalar(count))
599+
let count = Raw.gather(params: shapeTensor, indices: axes).product()
600+
return (value, { [shape = shapeTensor] in
601+
$0.broadcast(toShape: shape) / Tensor(count)
619602
})
620603
}
621604
}

0 commit comments

Comments
 (0)