
Commit b30deb2

[TF] Refactor numeric reduction ops and fix rank issues. (#24042)
* Fix 'standardDeviation()' rank.
* Refactor all numeric reduction ops.
1 parent 3e8a1de
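For context on the rank-related part of the commit message, the sketch below (plain Swift over shape arrays; `reducedShape` is a hypothetical helper, not part of the library) illustrates the usual distinction between the two reduction families involved: `alongAxes` reductions keep each reduced dimension with size 1, so rank is preserved, while `squeezingAxes` reductions drop the reduced dimensions.

// Hypothetical helper, shapes only -- not the TensorFlow library code.
func reducedShape(_ shape: [Int], axes: Set<Int>, squeezing: Bool) -> [Int] {
  if squeezing {
    // Drop the reduced dimensions entirely: rank decreases.
    return shape.enumerated().filter { !axes.contains($0.offset) }.map { $0.element }
  }
  // Keep the reduced dimensions with size 1: rank is preserved.
  return shape.enumerated().map { axes.contains($0.offset) ? 1 : $0.element }
}

print(reducedShape([2, 3, 4], axes: [1], squeezing: false))  // [2, 1, 4]
print(reducedShape([2, 3, 4], axes: [1], squeezing: true))   // [2, 4]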

3 files changed: +241 −162 lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 10 additions & 27 deletions
@@ -569,25 +569,15 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
 
 extension Tensor where Scalar : TensorFlowFloatingPoint {
   @inlinable
-  func _vjpSum() -> (Tensor, (Tensor) -> Tensor) {
-    return (sum(), { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
-  }
-
-  @inlinable
-  func _vjpMean() -> (Tensor, (Tensor) -> Tensor) {
-    return (mean(), { [shape = shapeTensor, count = scalarCountTensor] in
-      ($0 / Tensor(count)).broadcast(toShape: shape)
-    })
-  }
-
-  @inlinable
-  func _vjpSum(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+  func _vjpSum(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
     let value = sum(alongAxes: axes)
     return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
   }
 
   @inlinable
-  func _vjpSum(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+  func _vjpSum(
+    squeezingAxes axes: Tensor<Int32>
+  ) -> (Tensor, (Tensor) -> Tensor) {
     let value = sum(squeezingAxes: axes)
     return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
   }
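As a standalone illustration of what these `_vjpSum` pullbacks do (a minimal sketch in plain Swift over a flat array, not the library code): the derivative of a sum with respect to every summed element is 1, so the pullback simply broadcasts the incoming gradient back to the input's shape.

// Minimal sketch: value and pullback of a full sum over a flat array.
// The pullback "broadcasts" the upstream scalar gradient to the input shape,
// i.e. repeats it once per element, because d(sum)/dx_i == 1.
func sumWithPullback(_ x: [Float]) -> (value: Float, pullback: (Float) -> [Float]) {
  let value = x.reduce(0, +)
  return (value, { upstream in Array(repeating: upstream, count: x.count) })
}

let (y, pullback) = sumWithPullback([1, 2, 3, 4])
print(y)            // 10.0
print(pullback(1))  // [1.0, 1.0, 1.0, 1.0]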
@@ -602,20 +592,13 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
   }
 
   @inlinable
-  func _vjpMean(squeezingAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
+  func _vjpMean(
+    squeezingAxes axes: Tensor<Int32>
+  ) -> (Tensor, (Tensor) -> Tensor) {
     let value = mean(squeezingAxes: axes)
-    return (value, { [shape = shapeTensor,
-                      count = axes.map { shape[$0] }.reduce(1, *)] in
-      $0.broadcast(toShape: shape) / Tensor(Scalar(count))
-    })
-  }
-
-  @inlinable
-  func _vjpMean(alongAxes axes: [Int32]) -> (Tensor, (Tensor) -> Tensor) {
-    let value = mean(alongAxes: axes)
-    return (value, { [shape = shapeTensor,
-                      count = axes.map { shape[$0] }.reduce(1, *)] in
-      $0.broadcast(toShape: shape) / Tensor(Scalar(count))
+    let count = Raw.gather(params: shapeTensor, indices: axes).product()
+    return (value, { [shape = shapeTensor] in
+      $0.broadcast(toShape: shape) / Tensor(count)
     })
   }
 }
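The new `_vjpMean(squeezingAxes:)` body computes the element count of the reduced dimensions on the tensor side: gather the sizes of the reduced axes from the shape, then take their product, and divide the broadcast gradient by that count. A hypothetical standalone sketch of the count computation in plain Swift, under assumed shape and axes values:

// Plain-Swift sketch of the count used by the mean pullback
// (the diff above does this with Raw.gather + product() on tensors).
let shape: [Int32] = [2, 3, 4]   // assumed input shape
let axes: [Int32] = [1, 2]       // assumed reduced axes
// Gather the reduced dimension sizes from the shape, then multiply them.
let count = axes.map { shape[Int($0)] }.reduce(1, *)
print(count)  // 12; each element of the broadcast gradient is divided by this count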
