@@ -1543,7 +1543,8 @@ public extension Tensor {
1543
1543
@inlinable
1544
1544
@differentiable ( wrt: ( self , other) where Scalar: TensorFlowFloatingPoint)
1545
1545
func replacing( with other: Tensor , where mask: Tensor < Bool > ) -> Tensor {
1546
- _Raw. select ( condition: mask, t: other, e: self )
1546
+ precondition ( self . shape == other. shape, " `self` and `other` must have the same shape. " )
1547
+ return _Raw. select ( condition: mask, t: other, e: self )
1547
1548
}
1548
1549
}
1549
1550
@@ -1590,6 +1591,7 @@ public extension Tensor where Scalar == Bool {
1590
1591
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1591
1592
@inlinable
1592
1593
func all( squeezingAxes axes: Int ... ) -> Tensor {
1594
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1593
1595
let axes = axes. map ( Int32 . init)
1594
1596
return _Raw. all ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: false )
1595
1597
}
@@ -1600,6 +1602,7 @@ public extension Tensor where Scalar == Bool {
1600
1602
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1601
1603
@inlinable
1602
1604
func any( squeezingAxes axes: Int ... ) -> Tensor {
1605
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1603
1606
let axes = axes. map ( Int32 . init)
1604
1607
return _Raw. any ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: false )
1605
1608
}
@@ -1610,6 +1613,7 @@ public extension Tensor where Scalar == Bool {
1610
1613
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1611
1614
@inlinable
1612
1615
func all( alongAxes axes: Int ... ) -> Tensor {
1616
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1613
1617
let axes = axes. map ( Int32 . init)
1614
1618
return _Raw. all ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: true )
1615
1619
}
@@ -1620,6 +1624,7 @@ public extension Tensor where Scalar == Bool {
1620
1624
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1621
1625
@inlinable
1622
1626
func any( alongAxes axes: Int ... ) -> Tensor {
1627
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1623
1628
let axes = axes. map ( Int32 . init)
1624
1629
return _Raw. any ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: true )
1625
1630
}
@@ -1650,6 +1655,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
1650
1655
@inlinable
1651
1656
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1652
1657
func max( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
1658
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1653
1659
return _Raw. max ( self , reductionIndices: axes, keepDims: false )
1654
1660
}
1655
1661
@@ -1679,7 +1685,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
1679
1685
@inlinable
1680
1686
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1681
1687
func min( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
1682
- _Raw. min ( self , reductionIndices: axes, keepDims: false )
1688
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1689
+ return _Raw. min ( self , reductionIndices: axes, keepDims: false )
1683
1690
}
1684
1691
1685
1692
/// Returns the minimum values along the specified axes. The reduced dimensions are removed.
@@ -1708,7 +1715,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
1708
1715
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1709
1716
@inlinable
1710
1717
func argmax( squeezingAxis axis: Int ) -> Tensor < Int32 > {
1711
- _Raw. argMax ( self , dimension: Tensor < Int32 > ( Int32 ( axis) ) )
1718
+ precondition ( isAxisInRange ( axis) , " Axis must be in the range `[-rank, rank)`. " )
1719
+ return _Raw. argMax ( self , dimension: Tensor < Int32 > ( Int32 ( axis) ) )
1712
1720
}
1713
1721
1714
1722
/// Returns the indices of the minimum values along the specified axes. The reduced dimensions
@@ -1717,7 +1725,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
1717
1725
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1718
1726
@inlinable
1719
1727
func argmin( squeezingAxis axis: Int ) -> Tensor < Int32 > {
1720
- _Raw. argMin ( self , dimension: Tensor < Int32 > ( Int32 ( axis) ) )
1728
+ precondition ( isAxisInRange ( axis) , " Axis must be in the range `[-rank, rank)`. " )
1729
+ return _Raw. argMin ( self , dimension: Tensor < Int32 > ( Int32 ( axis) ) )
1721
1730
}
1722
1731
1723
1732
/// Returns the minimum along the specified axes. The reduced dimensions are retained with
@@ -1727,7 +1736,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
1727
1736
@inlinable
1728
1737
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1729
1738
func min( alongAxes axes: Tensor < Int32 > ) -> Tensor {
1730
- _Raw. min ( self , reductionIndices: axes, keepDims: true )
1739
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1740
+ return _Raw. min ( self , reductionIndices: axes, keepDims: true )
1731
1741
}
1732
1742
1733
1743
/// Returns the minimum along the specified axes. The reduced dimensions are retained with
@@ -1759,7 +1769,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
1759
1769
@inlinable
1760
1770
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1761
1771
func max( alongAxes axes: Tensor < Int32 > ) -> Tensor {
1762
- _Raw. max ( self , reductionIndices: axes, keepDims: true )
1772
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1773
+ return _Raw. max ( self , reductionIndices: axes, keepDims: true )
1763
1774
}
1764
1775
1765
1776
/// Returns the minimum along the specified axes. The reduced dimensions are retained with
@@ -1886,7 +1897,8 @@ public extension Tensor where Scalar: Numeric {
1886
1897
@inlinable
1887
1898
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1888
1899
func sum( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
1889
- _Raw. sum ( self , reductionIndices: axes, keepDims: false )
1900
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1901
+ return _Raw. sum ( self , reductionIndices: axes, keepDims: false )
1890
1902
}
1891
1903
1892
1904
/// Returns the sum along the specified axes. The reduced dimensions are removed.
@@ -1921,7 +1933,8 @@ public extension Tensor where Scalar: Numeric {
1921
1933
@inlinable
1922
1934
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1923
1935
func sum( alongAxes axes: Tensor < Int32 > ) -> Tensor {
1924
- _Raw. sum ( self , reductionIndices: axes, keepDims: true )
1936
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1937
+ return _Raw. sum ( self , reductionIndices: axes, keepDims: true )
1925
1938
}
1926
1939
1927
1940
/// Returns the sum along the specified axes. The reduced dimensions are retained with value 1.
@@ -1953,7 +1966,8 @@ public extension Tensor where Scalar: Numeric {
1953
1966
@inlinable
1954
1967
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1955
1968
func product( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
1956
- _Raw. prod ( self , reductionIndices: axes, keepDims: false )
1969
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1970
+ return _Raw. prod ( self , reductionIndices: axes, keepDims: false )
1957
1971
}
1958
1972
1959
1973
/// Returns the product along the specified axes. The reduced dimensions are removed.
@@ -1990,7 +2004,8 @@ public extension Tensor where Scalar: Numeric {
1990
2004
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1991
2005
@inlinable
1992
2006
func product( alongAxes axes: Tensor < Int32 > ) -> Tensor {
1993
- _Raw. prod ( self , reductionIndices: axes, keepDims: true )
2007
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2008
+ return _Raw. prod ( self , reductionIndices: axes, keepDims: true )
1994
2009
}
1995
2010
1996
2011
/// Returns the product along the specified axes. The reduced dimensions are retained with
@@ -2021,7 +2036,8 @@ public extension Tensor where Scalar: Numeric {
2021
2036
@inlinable
2022
2037
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2023
2038
func mean( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2024
- _Raw. mean ( self , reductionIndices: axes, keepDims: false )
2039
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2040
+ return _Raw. mean ( self , reductionIndices: axes, keepDims: false )
2025
2041
}
2026
2042
2027
2043
/// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed.
@@ -2057,7 +2073,8 @@ public extension Tensor where Scalar: Numeric {
2057
2073
@inlinable
2058
2074
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2059
2075
func mean( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2060
- _Raw. mean ( self , reductionIndices: axes, keepDims: true )
2076
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2077
+ return _Raw. mean ( self , reductionIndices: axes, keepDims: true )
2061
2078
}
2062
2079
2063
2080
/// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained
@@ -2091,6 +2108,7 @@ public extension Tensor where Scalar: Numeric {
2091
2108
@inlinable
2092
2109
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2093
2110
func variance( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2111
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2094
2112
let squaredDiff = squaredDifference ( self , mean ( alongAxes: axes) )
2095
2113
return squaredDiff. mean ( squeezingAxes: axes)
2096
2114
}
@@ -2132,6 +2150,7 @@ public extension Tensor where Scalar: Numeric {
2132
2150
@inlinable
2133
2151
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2134
2152
func variance( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2153
+ precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2135
2154
let squaredDiff = squaredDifference ( self , mean ( alongAxes: axes) )
2136
2155
return squaredDiff. mean ( alongAxes: axes)
2137
2156
}
@@ -2229,7 +2248,8 @@ public extension Tensor where Scalar: Numeric {
2229
2248
exclusive: Bool = false ,
2230
2249
reverse: Bool = false
2231
2250
) -> Tensor {
2232
- _Raw. cumsum ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2251
+ precondition ( isAxisInRange ( axis) , " Axis must be in the range `[-rank, rank)`. " )
2252
+ return _Raw. cumsum ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2233
2253
}
2234
2254
2235
2255
/// Returns the cumulative product of this tensor along the specified axis. By default, this
@@ -2294,6 +2314,7 @@ public extension Tensor where Scalar: Numeric {
2294
2314
/// - exclusive: Indicates whether to perform an exclusive cumulative product.
2295
2315
/// - reverse: Indicates whether to perform the cumulative product in reversed order.
2296
2316
/// - Returns: Result of the cumulative product operation.
2317
+ /// - Precondition: `axis` must have rank `0`.
2297
2318
/// - Precondition: `axis` must be in the range `-rank..<rank`.
2298
2319
@inlinable
2299
2320
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
@@ -2302,7 +2323,8 @@ public extension Tensor where Scalar: Numeric {
2302
2323
exclusive: Bool = false ,
2303
2324
reverse: Bool = false
2304
2325
) -> Tensor {
2305
- _Raw. cumprod ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2326
+ precondition ( isAxisInRange ( axis) , " Axis must be in the range `[-rank, rank)`. " )
2327
+ return _Raw. cumprod ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2306
2328
}
2307
2329
}
2308
2330
0 commit comments