Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b878b67

Browse files
ocamporsaeta
authored andcommitted
Add more preconditions for Basic.swift module (#557)
* Create helper function to check axis precondition in Basic.swift * Add Tensor.split preconditions * Add preconditions for split * Add preconditions for tiled function * Add preconditions for reshaped
1 parent d902cb3 commit b878b67

File tree

1 file changed

+48
-5
lines changed

1 file changed

+48
-5
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ public extension TensorFlowScalar {
3434
}
3535

3636
public extension Tensor {
37+
/// Helper function that assess if `axis` is in the range `[-rank, rank)`, where `rank` is the rank of
38+
/// the provided tensors.
39+
@usableFromInline
40+
internal func preconditionAxis(_ axis: Int) {
41+
precondition(
42+
axis >= -rank && axis < rank,
43+
"""
44+
The axis must be in the range [-rank, rank)
45+
of the provided tensors.
46+
""")
47+
}
48+
3749
/// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors.
3850
/// Unpacks `N` tensors from this tensor by chipping it along the `axis` dimension, where `N`
3951
/// is inferred from this tensor's shape. For example, given a tensor with shape
@@ -59,6 +71,7 @@ public extension Tensor {
5971
@inlinable
6072
@differentiable(vjp: _vjpUnstacked(alongAxis:) where Scalar: TensorFlowFloatingPoint)
6173
func unstacked(alongAxis axis: Int = 0) -> [Tensor] {
74+
preconditionAxis(axis)
6275
let posAxis = axis < 0 ? axis + rank : axis
6376
return _Raw.unpack(value: self, num: Int64(shape[posAxis]), axis: Int64(posAxis))
6477
}
@@ -88,7 +101,11 @@ public extension Tensor {
88101
@inlinable
89102
@differentiable(vjp: _vjpSplit(count:alongAxis:) where Scalar: TensorFlowFloatingPoint)
90103
func split(count: Int, alongAxis axis: Int = 0) -> [Tensor] {
91-
_Raw.split(splitDim: Tensor<Int32>(Int32(axis)), value: self, numSplit: Int64(count))
104+
preconditionAxis(axis)
105+
precondition(
106+
shapeTensor[axis].scalarized() % Int32(count) == 0,
107+
"Number of ways to split should evenly divide the split dimension.")
108+
return _Raw.split(splitDim: Tensor<Int32>(Int32(axis)), value: self, numSplit: Int64(count))
92109
}
93110

94111
/// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces.
@@ -119,7 +136,11 @@ public extension Tensor {
119136
wrt: self,
120137
vjp: _vjpSplit(sizes:alongAxis:) where Scalar: TensorFlowFloatingPoint)
121138
func split(sizes: Tensor<Int32>, alongAxis axis: Int = 0) -> [Tensor] {
122-
_Raw.splitV(
139+
preconditionAxis(axis)
140+
precondition(
141+
shapeTensor[axis] == sizes.sum(),
142+
"The values in sizes must add up to the size of dimension axis.")
143+
return _Raw.splitV(
123144
value: self,
124145
sizeSplits: sizes,
125146
splitDim: Tensor<Int32>(Int32(axis)),
@@ -133,11 +154,16 @@ public extension Tensor {
133154
/// values of this tensor are replicated `multiples[i]` times along the `i`'th dimension. For
134155
/// example, tiling `[a b c d]` by `[2]` produces `[a b c d a b c d]`.
135156
///
157+
/// - Precondition: The expected `rank` of multiples must be `1`.
136158
/// - Precondition: The shape of `multiples` must be `[tensor.rank]`.
137159
@inlinable
138160
@differentiable(wrt: self, vjp: _vjpTiled(multiples:) where Scalar: TensorFlowFloatingPoint)
139161
func tiled(multiples: Tensor<Int32>) -> Tensor {
140-
_Raw.tile(self, multiples: multiples)
162+
precondition(multiples.rank == 1, "The expected rank of multiples must be 1.")
163+
precondition(
164+
rank == multiples.shapeTensor.scalarized(),
165+
"The shape of multiples must be [tensor.rank].")
166+
return _Raw.tile(self, multiples: multiples)
141167
}
142168

143169
/// Reshape to the shape of the specified `Tensor`.
@@ -162,7 +188,23 @@ public extension Tensor {
162188
@inlinable
163189
@differentiable(wrt: self, vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint)
164190
func reshaped(toShape newShape: Tensor<Int32>) -> Tensor {
165-
_Raw.reshape(self, shape: newShape)
191+
let totalNegative = newShape.scalars.filter({$0 == -1}).count
192+
let positiveShapeSizes = newShape.scalars.filter({$0 > 0})
193+
let newShapeScalarCount = positiveShapeSizes.reduce(1, {$0 * $1})
194+
195+
precondition(totalNegative <= 1, "Only one input size may be -1.")
196+
197+
if totalNegative == 1 {
198+
precondition(
199+
scalarCount % Int(newShapeScalarCount) == 0,
200+
"The number of scalars must be a multiple of the new shape.")
201+
} else {
202+
precondition(
203+
scalarCount == newShapeScalarCount,
204+
"The number of scalars must match the new shape.")
205+
}
206+
207+
return _Raw.reshape(self, shape: newShape)
166208
}
167209

168210
/// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order.
@@ -408,7 +450,8 @@ public extension Tensor {
408450
atIndices indices: Tensor<Index>,
409451
alongAxis axis: Int = 0
410452
) -> Tensor {
411-
_Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
453+
preconditionAxis(axis)
454+
return _Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
412455
}
413456

414457
/// Returns slices of this tensor at `indices` along the `axis` dimension, while ignoring the

0 commit comments

Comments
 (0)