Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 723d17a

Browse files
SumanSudhirdan-zheng
authored andcommitted
added more preconditions (#614)
- `Tensor.replacing(with:where:)` - `Tensor.all({along,squeezing}Axes:)` - `Tensor.any({along,squeezing}Axes:)` - `Tensor.max({along,squeezing}Axes:)` - `Tensor.min({along,squeezing}Axes:)` - `Tensor.argmax({along,squeezing}Axes:)` - `Tensor.argmin({along,squeezing}Axes:)` - `Tensor.sum({along,squeezing}Axes:)` - `Tensor.product({along,squeezing}Axes:)` - `Tensor.mean({along,squeezing}Axes:)` - `Tensor.variance({along,squeezing}Axes:)` - `Tensor.cumulativeSum(alongAxis:exclusive:reverse:)` - `Tensor.cumulativeProduct(alongAxis:exclusive:reverse:)`
1 parent aba81ce commit 723d17a

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,13 @@ extension Tensor {
12191219
return axis >= -rank && axis < rank
12201220
}
12211221

1222+
/// Returns `true` if the given scalar tensor is in the range `[-rank, rank)`.
1223+
@usableFromInline
1224+
internal func isAxisInRange(_ axis: Tensor<Int32>) -> Bool {
1225+
precondition(axis.rank == 0, "Axis must have rank 0.")
1226+
return areAxesInRange(axis.scalars)
1227+
}
1228+
12221229
/// Returns `true` if all given axes are in the range `[-rank, rank)`.
12231230
@usableFromInline
12241231
internal func areAxesInRange<T: BinaryInteger>(_ axes: [T]) -> Bool {
@@ -1228,7 +1235,7 @@ extension Tensor {
12281235
/// Returns `true` if all scalars of the given 1-D tensor are in the range `[-rank, rank)`.
12291236
@usableFromInline
12301237
internal func areAxesInRange(_ axes: Tensor<Int32>) -> Bool {
1231-
precondition(axes.rank == 1, "Axes must have rank 1")
1238+
precondition(axes.rank == 1, "Axes must have rank 1.")
12321239
return areAxesInRange(axes.scalars)
12331240
}
12341241
}

Sources/TensorFlow/Operators/Math.swift

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,7 +1543,8 @@ public extension Tensor {
15431543
@inlinable
15441544
@differentiable(wrt: (self, other) where Scalar: TensorFlowFloatingPoint)
15451545
func replacing(with other: Tensor, where mask: Tensor<Bool>) -> Tensor {
1546-
_Raw.select(condition: mask, t: other, e: self)
1546+
precondition(self.shape == other.shape, "`self` and `other` must have the same shape.")
1547+
return _Raw.select(condition: mask, t: other, e: self)
15471548
}
15481549
}
15491550

@@ -1590,6 +1591,7 @@ public extension Tensor where Scalar == Bool {
15901591
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
15911592
@inlinable
15921593
func all(squeezingAxes axes: Int...) -> Tensor {
1594+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
15931595
let axes = axes.map(Int32.init)
15941596
return _Raw.all(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
15951597
}
@@ -1600,6 +1602,7 @@ public extension Tensor where Scalar == Bool {
16001602
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
16011603
@inlinable
16021604
func any(squeezingAxes axes: Int...) -> Tensor {
1605+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
16031606
let axes = axes.map(Int32.init)
16041607
return _Raw.any(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
16051608
}
@@ -1610,6 +1613,7 @@ public extension Tensor where Scalar == Bool {
16101613
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
16111614
@inlinable
16121615
func all(alongAxes axes: Int...) -> Tensor {
1616+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
16131617
let axes = axes.map(Int32.init)
16141618
return _Raw.all(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
16151619
}
@@ -1620,6 +1624,7 @@ public extension Tensor where Scalar == Bool {
16201624
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
16211625
@inlinable
16221626
func any(alongAxes axes: Int...) -> Tensor {
1627+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
16231628
let axes = axes.map(Int32.init)
16241629
return _Raw.any(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
16251630
}
@@ -1650,6 +1655,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
16501655
@inlinable
16511656
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
16521657
func max(squeezingAxes axes: Tensor<Int32>) -> Tensor {
1658+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
16531659
return _Raw.max(self, reductionIndices: axes, keepDims: false)
16541660
}
16551661

@@ -1679,7 +1685,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
16791685
@inlinable
16801686
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
16811687
func min(squeezingAxes axes: Tensor<Int32>) -> Tensor {
1682-
_Raw.min(self, reductionIndices: axes, keepDims: false)
1688+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1689+
return _Raw.min(self, reductionIndices: axes, keepDims: false)
16831690
}
16841691

16851692
/// Returns the minimum values along the specified axes. The reduced dimensions are removed.
@@ -1708,7 +1715,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
17081715
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
17091716
@inlinable
17101717
func argmax(squeezingAxis axis: Int) -> Tensor<Int32> {
1711-
_Raw.argMax(self, dimension: Tensor<Int32>(Int32(axis)))
1718+
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
1719+
return _Raw.argMax(self, dimension: Tensor<Int32>(Int32(axis)))
17121720
}
17131721

17141722
/// Returns the indices of the minimum values along the specified axes. The reduced dimensions
@@ -1717,7 +1725,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
17171725
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
17181726
@inlinable
17191727
func argmin(squeezingAxis axis: Int) -> Tensor<Int32> {
1720-
_Raw.argMin(self, dimension: Tensor<Int32>(Int32(axis)))
1728+
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
1729+
return _Raw.argMin(self, dimension: Tensor<Int32>(Int32(axis)))
17211730
}
17221731

17231732
/// Returns the minimum along the specified axes. The reduced dimensions are retained with
@@ -1727,7 +1736,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
17271736
@inlinable
17281737
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
17291738
func min(alongAxes axes: Tensor<Int32>) -> Tensor {
1730-
_Raw.min(self, reductionIndices: axes, keepDims: true)
1739+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1740+
return _Raw.min(self, reductionIndices: axes, keepDims: true)
17311741
}
17321742

17331743
/// Returns the minimum along the specified axes. The reduced dimensions are retained with
@@ -1759,7 +1769,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
17591769
@inlinable
17601770
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
17611771
func max(alongAxes axes: Tensor<Int32>) -> Tensor {
1762-
_Raw.max(self, reductionIndices: axes, keepDims: true)
1772+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1773+
return _Raw.max(self, reductionIndices: axes, keepDims: true)
17631774
}
17641775

17651776
/// Returns the minimum along the specified axes. The reduced dimensions are retained with
@@ -1886,7 +1897,8 @@ public extension Tensor where Scalar: Numeric {
18861897
@inlinable
18871898
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
18881899
func sum(squeezingAxes axes: Tensor<Int32>) -> Tensor {
1889-
_Raw.sum(self, reductionIndices: axes, keepDims: false)
1900+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1901+
return _Raw.sum(self, reductionIndices: axes, keepDims: false)
18901902
}
18911903

18921904
/// Returns the sum along the specified axes. The reduced dimensions are removed.
@@ -1921,7 +1933,8 @@ public extension Tensor where Scalar: Numeric {
19211933
@inlinable
19221934
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
19231935
func sum(alongAxes axes: Tensor<Int32>) -> Tensor {
1924-
_Raw.sum(self, reductionIndices: axes, keepDims: true)
1936+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1937+
return _Raw.sum(self, reductionIndices: axes, keepDims: true)
19251938
}
19261939

19271940
/// Returns the sum along the specified axes. The reduced dimensions are retained with value 1.
@@ -1953,7 +1966,8 @@ public extension Tensor where Scalar: Numeric {
19531966
@inlinable
19541967
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
19551968
func product(squeezingAxes axes: Tensor<Int32>) -> Tensor {
1956-
_Raw.prod(self, reductionIndices: axes, keepDims: false)
1969+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1970+
return _Raw.prod(self, reductionIndices: axes, keepDims: false)
19571971
}
19581972

19591973
/// Returns the product along the specified axes. The reduced dimensions are removed.
@@ -1990,7 +2004,8 @@ public extension Tensor where Scalar: Numeric {
19902004
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
19912005
@inlinable
19922006
func product(alongAxes axes: Tensor<Int32>) -> Tensor {
1993-
_Raw.prod(self, reductionIndices: axes, keepDims: true)
2007+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2008+
return _Raw.prod(self, reductionIndices: axes, keepDims: true)
19942009
}
19952010

19962011
/// Returns the product along the specified axes. The reduced dimensions are retained with
@@ -2021,7 +2036,8 @@ public extension Tensor where Scalar: Numeric {
20212036
@inlinable
20222037
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
20232038
func mean(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2024-
_Raw.mean(self, reductionIndices: axes, keepDims: false)
2039+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2040+
return _Raw.mean(self, reductionIndices: axes, keepDims: false)
20252041
}
20262042

20272043
/// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed.
@@ -2057,7 +2073,8 @@ public extension Tensor where Scalar: Numeric {
20572073
@inlinable
20582074
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
20592075
func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
2060-
_Raw.mean(self, reductionIndices: axes, keepDims: true)
2076+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2077+
return _Raw.mean(self, reductionIndices: axes, keepDims: true)
20612078
}
20622079

20632080
/// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained
@@ -2091,6 +2108,7 @@ public extension Tensor where Scalar: Numeric {
20912108
@inlinable
20922109
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
20932110
func variance(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2111+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
20942112
let squaredDiff = squaredDifference(self, mean(alongAxes: axes))
20952113
return squaredDiff.mean(squeezingAxes: axes)
20962114
}
@@ -2132,6 +2150,7 @@ public extension Tensor where Scalar: Numeric {
21322150
@inlinable
21332151
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
21342152
func variance(alongAxes axes: Tensor<Int32>) -> Tensor {
2153+
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
21352154
let squaredDiff = squaredDifference(self, mean(alongAxes: axes))
21362155
return squaredDiff.mean(alongAxes: axes)
21372156
}
@@ -2229,7 +2248,8 @@ public extension Tensor where Scalar: Numeric {
22292248
exclusive: Bool = false,
22302249
reverse: Bool = false
22312250
) -> Tensor {
2232-
_Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse)
2251+
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
2252+
return _Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse)
22332253
}
22342254

22352255
/// Returns the cumulative product of this tensor along the specified axis. By default, this
@@ -2294,6 +2314,7 @@ public extension Tensor where Scalar: Numeric {
22942314
/// - exclusive: Indicates whether to perform an exclusive cumulative product.
22952315
/// - reverse: Indicates whether to perform the cumulative product in reversed order.
22962316
/// - Returns: Result of the cumulative product operation.
2317+
/// - Precondition: `axis` must have rank `0`.
22972318
/// - Precondition: `axis` must be in the range `-rank..<rank`.
22982319
@inlinable
22992320
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
@@ -2302,7 +2323,8 @@ public extension Tensor where Scalar: Numeric {
23022323
exclusive: Bool = false,
23032324
reverse: Bool = false
23042325
) -> Tensor {
2305-
_Raw.cumprod(self, axis: axis, exclusive: exclusive, reverse: reverse)
2326+
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
2327+
return _Raw.cumprod(self, axis: axis, exclusive: exclusive, reverse: reverse)
23062328
}
23072329
}
23082330

0 commit comments

Comments
 (0)