Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ec83fa7

Browse files
jon-towsaeta
authored andcommitted
Add precondition message for Tensor initialization (#301)
1 parent d1decbe commit ec83fa7

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ public extension Tensor {
140140
@inlinable
141141
@differentiable(wrt: self, vjp: _vjpScalarized where Scalar: TensorFlowFloatingPoint)
142142
func scalarized() -> Scalar {
143+
precondition(shape.contiguousSize == 1,
144+
"This tensor must have exactly one scalar but contains \(shape.contiguousSize).")
143145
return reshaped(to: []).scalar!
144146
}
145147
}
@@ -226,9 +228,14 @@ public extension Tensor {
226228
/// - Parameters:
227229
/// - shape: The shape of the tensor.
228230
/// - scalars: The scalar contents of the tensor.
229-
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
231+
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
230232
@inlinable
231233
init(shape: TensorShape, scalars: [Scalar]) {
234+
precondition(shape.contiguousSize == scalars.count,
235+
"""
236+
The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
237+
provided.
238+
""")
232239
self = scalars.withUnsafeBufferPointer { bufferPointer in
233240
Tensor(shape: shape, scalars: bufferPointer)
234241
}
@@ -239,11 +246,14 @@ public extension Tensor {
239246
/// - Parameters:
240247
/// - shape: The shape of the tensor.
241248
/// - scalars: The scalar contents of the tensor.
242-
/// - Precondition: The number of scalars must equal the product of the
243-
/// dimensions of the shape.
249+
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
244250
@inlinable
245251
init(shape: TensorShape, scalars: UnsafeBufferPointer<Scalar>) {
246-
precondition(scalars.count == shape.contiguousSize)
252+
precondition(shape.contiguousSize == scalars.count,
253+
"""
254+
The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
255+
provided.
256+
""")
247257
let handle = TensorHandle<Scalar>(
248258
shape: shape.dimensions,
249259
scalarsInitializer: { address in
@@ -257,11 +267,14 @@ public extension Tensor {
257267
/// - Parameters:
258268
/// - shape: The shape of the tensor.
259269
/// - scalars: The scalar contents of the tensor.
260-
/// - Precondition: The number of scalars must equal the product of the
261-
/// dimensions of the shape.
270+
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
262271
@inlinable
263272
init<C: RandomAccessCollection>(shape: TensorShape, scalars: C) where C.Element == Scalar {
264-
precondition(scalars.count == shape.contiguousSize)
273+
precondition(shape.contiguousSize == scalars.count,
274+
"""
275+
The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
276+
provided.
277+
""")
265278
let handle = TensorHandle<Scalar>(
266279
shape: shape.dimensions,
267280
scalarsInitializer: { addr in

0 commit comments

Comments
 (0)