@@ -140,6 +140,8 @@ public extension Tensor {
140
140
@inlinable
141
141
@differentiable ( wrt: self , vjp: _vjpScalarized where Scalar: TensorFlowFloatingPoint)
142
142
func scalarized( ) -> Scalar {
143
+ precondition ( shape. contiguousSize == 1 ,
144
+ " This tensor must have exactly one scalar but contains \( shape. contiguousSize) . " )
143
145
return reshaped ( to: [ ] ) . scalar!
144
146
}
145
147
}
@@ -226,9 +228,14 @@ public extension Tensor {
226
228
/// - Parameters:
227
229
/// - shape: The shape of the tensor.
228
230
/// - scalars: The scalar contents of the tensor.
229
- /// - Precondition: The number of scalars must equal the product of the dimensions of the shape .
231
+ /// - Precondition: The product of the dimensions of the shape must equal the number of scalars .
230
232
@inlinable
231
233
init ( shape: TensorShape , scalars: [ Scalar ] ) {
234
+ precondition ( shape. contiguousSize == scalars. count,
235
+ """
236
+ The shape requires \( shape. contiguousSize) scalars but \( scalars. count) were \
237
+ provided.
238
+ """ )
232
239
self = scalars. withUnsafeBufferPointer { bufferPointer in
233
240
Tensor ( shape: shape, scalars: bufferPointer)
234
241
}
@@ -239,11 +246,14 @@ public extension Tensor {
239
246
/// - Parameters:
240
247
/// - shape: The shape of the tensor.
241
248
/// - scalars: The scalar contents of the tensor.
242
- /// - Precondition: The number of scalars must equal the product of the
243
- /// dimensions of the shape.
249
+ /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
244
250
@inlinable
245
251
init ( shape: TensorShape , scalars: UnsafeBufferPointer < Scalar > ) {
246
- precondition ( scalars. count == shape. contiguousSize)
252
+ precondition ( shape. contiguousSize == scalars. count,
253
+ """
254
+ The shape requires \( shape. contiguousSize) scalars but \( scalars. count) were \
255
+ provided.
256
+ """ )
247
257
let handle = TensorHandle < Scalar > (
248
258
shape: shape. dimensions,
249
259
scalarsInitializer: { address in
@@ -257,11 +267,14 @@ public extension Tensor {
257
267
/// - Parameters:
258
268
/// - shape: The shape of the tensor.
259
269
/// - scalars: The scalar contents of the tensor.
260
- /// - Precondition: The number of scalars must equal the product of the
261
- /// dimensions of the shape.
270
+ /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
262
271
@inlinable
263
272
init < C: RandomAccessCollection > ( shape: TensorShape , scalars: C ) where C. Element == Scalar {
264
- precondition ( scalars. count == shape. contiguousSize)
273
+ precondition ( shape. contiguousSize == scalars. count,
274
+ """
275
+ The shape requires \( shape. contiguousSize) scalars but \( scalars. count) were \
276
+ provided.
277
+ """ )
265
278
let handle = TensorHandle < Scalar > (
266
279
shape: shape. dimensions,
267
280
scalarsInitializer: { addr in
0 commit comments