Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 33b3b5b

Browse files
authored
Added support for 'Tensor.clipped(min:max:)' and its VJP. (#361)
1 parent a3b3aa0 commit 33b3b5b

File tree

3 files changed

+70
-11
lines changed

3 files changed

+70
-11
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -675,16 +675,6 @@ public extension Tensor {
675675
}
676676
}
677677

678-
public extension Tensor where Scalar: Numeric {
679-
/// Returns a tensor by clipping scalars to a specified minimum and maximum.
680-
// FIXME: Define a derivative function.
681-
// @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
682-
@inlinable
683-
func clipped(min: Tensor, max: Tensor) -> Tensor {
684-
Raw.clipByValue(t: self, clipValueMin: min, clipValueMax: max)
685-
}
686-
}
687-
688678
//===------------------------------------------------------------------------------------------===//
689679
// Broadcasting
690680
//===------------------------------------------------------------------------------------------===//

Sources/TensorFlow/Operators/Math.swift

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public extension PointwiseMultiplicative {
3636
}
3737

3838
//===------------------------------------------------------------------------------------------===//
39-
// Generic elementary functions
39+
// Generic Elementary Functions
4040
//===------------------------------------------------------------------------------------------===//
4141

4242
extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
@@ -494,6 +494,58 @@ public extension Tensor where Scalar == Bool {
494494
}
495495
}
496496

497+
public extension Tensor where Scalar: TensorFlowNumeric {
    /// Returns a tensor whose scalars are clipped to the range `[min, max]`,
    /// i.e. `max(min(self, max), min)`.
    ///
    /// - Parameters:
    ///   - min: The elementwise lower bound; broadcast against `self` by the op.
    ///   - max: The elementwise upper bound; broadcast against `self` by the op.
    /// - Returns: A tensor with the same shape as `self` after broadcasting.
    ///
    /// The custom VJP `_vjpClipped` is registered so this is differentiable
    /// w.r.t. `self`, `min`, and `max` for floating-point scalars.
    @inlinable
    @differentiable(vjp: _vjpClipped where Scalar: TensorFlowFloatingPoint)
    func clipped(min: Tensor, max: Tensor) -> Tensor {
        Raw.clipByValue(t: self, clipValueMin: min, clipValueMax: max)
    }

    /// Returns `max(min(self, max), min)` with a scalar upper bound.
    ///
    /// Differentiable only w.r.t. `(self, min)` — the scalar `max` is not an
    /// active differentiation argument.
    @inlinable
    @differentiable(wrt: (self, min) where Scalar: TensorFlowFloatingPoint)
    func clipped(min: Tensor, max: Scalar) -> Tensor {
        clipped(min: min, max: Tensor(max))
    }

    /// Returns `max(min(self, max), min)` with a scalar lower bound.
    ///
    /// Differentiable only w.r.t. `(self, max)` — the scalar `min` is not an
    /// active differentiation argument.
    @inlinable
    @differentiable(wrt: (self, max) where Scalar: TensorFlowFloatingPoint)
    func clipped(min: Scalar, max: Tensor) -> Tensor {
        clipped(min: Tensor(min), max: max)
    }

    /// Returns `max(min(self, max), min)` with scalar bounds on both sides.
    ///
    /// Differentiable only w.r.t. `self`.
    @inlinable
    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
    func clipped(min: Scalar, max: Scalar) -> Tensor {
        clipped(min: Tensor(min), max: Tensor(max))
    }
}
526+
527+
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
    /// VJP for `clipped(min:max:)`.
    ///
    /// Pullback semantics (matching TensorFlow's `ClipByValue` gradient):
    /// - w.r.t. `self`: the incoming cotangent `v` passes through where the
    ///   value was inside `[min, max]`, and is zeroed where it was clipped.
    /// - w.r.t. `min`: receives `v` exactly where `self < min` (the lower
    ///   bound was the value actually produced).
    /// - w.r.t. `max`: receives `v` exactly where `self > max`.
    ///
    /// Because `min`/`max` may have been broadcast against `self`, each
    /// gradient is sum-reduced over the broadcast axes (computed by
    /// `Raw.broadcastGradientArgs`) and reshaped back to the operand's shape.
    @inlinable
    func _vjpClipped(min: Tensor, max: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor, Tensor)) {
        (clipped(min: min, max: max), { v in
            let selfShape = self.shapeTensor
            let minShape = min.shapeTensor
            let maxShape = max.shapeTensor
            let zeros = Tensor(zerosLike: v)
            // Masks over the (broadcast) result shape: where clipping occurred.
            let minMask = self .< min
            let maxMask = self .> max
            // Inside the range: gradient flows to `self`; outside: zeroed.
            let selfGradient = v.replacing(with: zeros, where: minMask.elementsLogicalOr(maxMask))
            // Clipped elements route the cotangent to the active bound instead.
            let minGradient = zeros.replacing(with: v, where: minMask)
            let maxGradient = zeros.replacing(with: v, where: maxMask)
            // Reduction axes that undo broadcasting of each operand.
            let (selfAxes, minAxes) = Raw.broadcastGradientArgs(s0: selfShape, s1: minShape)
            // NOTE(review): `self` reduction axes from the second call are
            // discarded; assumes both calls agree on `s0` — TODO confirm.
            let (_, maxAxes) = Raw.broadcastGradientArgs(s0: selfShape, s1: maxShape)
            return (selfGradient.sum(squeezingAxes: selfAxes).reshaped(toShape: selfShape),
                minGradient.sum(squeezingAxes: minAxes).reshaped(toShape: minShape),
                maxGradient.sum(squeezingAxes: maxAxes).reshaped(toShape: maxShape))
        })
    }
}
548+
497549
//===------------------------------------------------------------------------------------------===//
498550
// Element-wise Unary Math Functions
499551
//===------------------------------------------------------------------------------------------===//

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ final class MathOperatorTests: XCTestCase {
5959
{ x in root(x, 3) }, { x in Float.root(x, 3) })
6060
}
6161

62+
func testClipping() {
    // Arrange: a fixed 5x5 input with values spread across [0, 1].
    let input = Tensor<Float>([
        [0.45031791, 0.41123222, 0.53928467, 0.47167023, 0.15483777],
        [0.49975705, 0.71807549, 0.30396056, 0.26904690, 0.01404393],
        [0.16950939, 0.41085612, 0.79503016, 0.11977817, 0.99728241],
        [0.62510073, 0.17344792, 0.15406050, 0.40758517, 0.93683817],
        [0.15653343, 0.50502756, 0.99365925, 0.84617581, 0.17422509]])
    // Values below 0.2 clamp to 0.2, values above 0.5 clamp to 0.5,
    // everything in between passes through unchanged.
    let expected = Tensor<Float>([
        [0.45031791, 0.41123222, 0.50000000, 0.47167023, 0.20000000],
        [0.49975705, 0.50000000, 0.30396056, 0.26904690, 0.20000000],
        [0.20000000, 0.41085612, 0.50000000, 0.20000000, 0.50000000],
        [0.50000000, 0.20000000, 0.20000000, 0.40758517, 0.50000000],
        [0.20000000, 0.50000000, 0.50000000, 0.50000000, 0.20000000]])
    // Act & assert.
    let actual = input.clipped(min: 0.2, max: 0.5)
    assertEqual(actual, expected, accuracy: 0.0001)
}
78+
6279
func testRsqrt() {
6380
let x = Tensor<Double>([1, 0.25, 1.0 / 9.0, 0.0625, 0.04])
6481
let target = Tensor<Double>([1, 2, 3, 4, 5]).sum()

0 commit comments

Comments
 (0)