Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

added more precondition #614

Merged
merged 5 commits into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Sources/TensorFlow/Operators/Basic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,13 @@ extension Tensor {
return axis >= -rank && axis < rank
}

/// Returns `true` if the given scalar tensor is in the range `[-rank, rank)`.
@usableFromInline
internal func isAxisInRange(_ axis: Tensor<Int32>) -> Bool {
precondition(axis.rank == 0, "Axis must have rank 0.")
return areAxesInRange(axis.scalars)
}

/// Returns `true` if all given axes are in the range `[-rank, rank)`.
@usableFromInline
internal func areAxesInRange<T: BinaryInteger>(_ axes: [T]) -> Bool {
Expand All @@ -1228,7 +1235,7 @@ extension Tensor {
/// Returns `true` if all scalars of the given 1-D tensor are in the range `[-rank, rank)`.
@usableFromInline
internal func areAxesInRange(_ axes: Tensor<Int32>) -> Bool {
precondition(axes.rank == 1, "Axes must have rank 1")
precondition(axes.rank == 1, "Axes must have rank 1.")
return areAxesInRange(axes.scalars)
}
}
50 changes: 36 additions & 14 deletions Sources/TensorFlow/Operators/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,8 @@ public extension Tensor {
@inlinable
@differentiable(wrt: (self, other) where Scalar: TensorFlowFloatingPoint)
func replacing(with other: Tensor, where mask: Tensor<Bool>) -> Tensor {
_Raw.select(condition: mask, t: other, e: self)
precondition(self.shape == other.shape, "`self` and `other` must have the same shape.")
return _Raw.select(condition: mask, t: other, e: self)
}
}

Expand Down Expand Up @@ -1590,6 +1591,7 @@ public extension Tensor where Scalar == Bool {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func all(squeezingAxes axes: Int...) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
let axes = axes.map(Int32.init)
return _Raw.all(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
}
Expand All @@ -1600,6 +1602,7 @@ public extension Tensor where Scalar == Bool {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func any(squeezingAxes axes: Int...) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
let axes = axes.map(Int32.init)
return _Raw.any(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
}
Expand All @@ -1610,6 +1613,7 @@ public extension Tensor where Scalar == Bool {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func all(alongAxes axes: Int...) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
let axes = axes.map(Int32.init)
return _Raw.all(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
}
Expand All @@ -1620,6 +1624,7 @@ public extension Tensor where Scalar == Bool {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func any(alongAxes axes: Int...) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
let axes = axes.map(Int32.init)
return _Raw.any(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
}
Expand Down Expand Up @@ -1650,6 +1655,7 @@ public extension Tensor where Scalar: Numeric & Comparable {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func max(squeezingAxes axes: Tensor<Int32>) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.max(self, reductionIndices: axes, keepDims: false)
}

Expand Down Expand Up @@ -1679,7 +1685,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func min(squeezingAxes axes: Tensor<Int32>) -> Tensor {
_Raw.min(self, reductionIndices: axes, keepDims: false)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.min(self, reductionIndices: axes, keepDims: false)
}

/// Returns the minimum values along the specified axes. The reduced dimensions are removed.
Expand Down Expand Up @@ -1708,7 +1715,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func argmax(squeezingAxis axis: Int) -> Tensor<Int32> {
_Raw.argMax(self, dimension: Tensor<Int32>(Int32(axis)))
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
return _Raw.argMax(self, dimension: Tensor<Int32>(Int32(axis)))
}

/// Returns the indices of the minimum values along the specified axes. The reduced dimensions
Expand All @@ -1717,7 +1725,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func argmin(squeezingAxis axis: Int) -> Tensor<Int32> {
_Raw.argMin(self, dimension: Tensor<Int32>(Int32(axis)))
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
return _Raw.argMin(self, dimension: Tensor<Int32>(Int32(axis)))
}

/// Returns the minimum along the specified axes. The reduced dimensions are retained with
Expand All @@ -1727,7 +1736,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func min(alongAxes axes: Tensor<Int32>) -> Tensor {
_Raw.min(self, reductionIndices: axes, keepDims: true)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.min(self, reductionIndices: axes, keepDims: true)
}

/// Returns the minimum along the specified axes. The reduced dimensions are retained with
Expand Down Expand Up @@ -1759,7 +1769,8 @@ public extension Tensor where Scalar: Numeric & Comparable {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func max(alongAxes axes: Tensor<Int32>) -> Tensor {
_Raw.max(self, reductionIndices: axes, keepDims: true)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.max(self, reductionIndices: axes, keepDims: true)
}

/// Returns the minimum along the specified axes. The reduced dimensions are retained with
Expand Down Expand Up @@ -1886,7 +1897,8 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func sum(squeezingAxes axes: Tensor<Int32>) -> Tensor {
_Raw.sum(self, reductionIndices: axes, keepDims: false)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.sum(self, reductionIndices: axes, keepDims: false)
}

/// Returns the sum along the specified axes. The reduced dimensions are removed.
Expand Down Expand Up @@ -1921,7 +1933,8 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func sum(alongAxes axes: Tensor<Int32>) -> Tensor {
_Raw.sum(self, reductionIndices: axes, keepDims: true)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.sum(self, reductionIndices: axes, keepDims: true)
}

/// Returns the sum along the specified axes. The reduced dimensions are retained with value 1.
Expand Down Expand Up @@ -1953,7 +1966,8 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func product(squeezingAxes axes: Tensor<Int32>) -> Tensor {
_Raw.prod(self, reductionIndices: axes, keepDims: false)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.prod(self, reductionIndices: axes, keepDims: false)
}

/// Returns the product along the specified axes. The reduced dimensions are removed.
Expand Down Expand Up @@ -1990,7 +2004,8 @@ public extension Tensor where Scalar: Numeric {
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
@inlinable
func product(alongAxes axes: Tensor<Int32>) -> Tensor {
_Raw.prod(self, reductionIndices: axes, keepDims: true)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.prod(self, reductionIndices: axes, keepDims: true)
}

/// Returns the product along the specified axes. The reduced dimensions are retained with
Expand Down Expand Up @@ -2021,7 +2036,8 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func mean(squeezingAxes axes: Tensor<Int32>) -> Tensor {
_Raw.mean(self, reductionIndices: axes, keepDims: false)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.mean(self, reductionIndices: axes, keepDims: false)
}

/// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed.
Expand Down Expand Up @@ -2057,7 +2073,8 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
_Raw.mean(self, reductionIndices: axes, keepDims: true)
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
return _Raw.mean(self, reductionIndices: axes, keepDims: true)
}

/// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained
Expand Down Expand Up @@ -2091,6 +2108,7 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func variance(squeezingAxes axes: Tensor<Int32>) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
let squaredDiff = squaredDifference(self, mean(alongAxes: axes))
return squaredDiff.mean(squeezingAxes: axes)
}
Expand Down Expand Up @@ -2132,6 +2150,7 @@ public extension Tensor where Scalar: Numeric {
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func variance(alongAxes axes: Tensor<Int32>) -> Tensor {
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
let squaredDiff = squaredDifference(self, mean(alongAxes: axes))
return squaredDiff.mean(alongAxes: axes)
}
Expand Down Expand Up @@ -2229,7 +2248,8 @@ public extension Tensor where Scalar: Numeric {
exclusive: Bool = false,
reverse: Bool = false
) -> Tensor {
_Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse)
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
return _Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse)
}

/// Returns the cumulative product of this tensor along the specified axis. By default, this
Expand Down Expand Up @@ -2294,6 +2314,7 @@ public extension Tensor where Scalar: Numeric {
/// - exclusive: Indicates whether to perform an exclusive cumulative product.
/// - reverse: Indicates whether to perform the cumulative product in reversed order.
/// - Returns: Result of the cumulative product operation.
/// - Precondition: `axis` must have rank `0`.
/// - Precondition: `axis` must be in the range `-rank..<rank`.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
Expand All @@ -2302,7 +2323,8 @@ public extension Tensor where Scalar: Numeric {
exclusive: Bool = false,
reverse: Bool = false
) -> Tensor {
_Raw.cumprod(self, axis: axis, exclusive: exclusive, reverse: reverse)
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
return _Raw.cumprod(self, axis: axis, exclusive: exclusive, reverse: reverse)
}
}

Expand Down