Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add precondition message for Tensor initialization #301

Merged
merged 8 commits into from
Jun 27, 2019
27 changes: 20 additions & 7 deletions Sources/TensorFlow/Core/Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ public extension Tensor {
@inlinable
@differentiable(wrt: self, vjp: _vjpScalarized where Scalar: TensorFlowFloatingPoint)
func scalarized() -> Scalar {
precondition(shape.contiguousSize == 1,
"This tensor must have exactly one scalar but contains \(shape.contiguousSize).")
return reshaped(to: []).scalar!
}
}
Expand Down Expand Up @@ -226,9 +228,14 @@ public extension Tensor {
/// - Parameters:
/// - shape: The shape of the tensor.
/// - scalars: The scalar contents of the tensor.
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
@inlinable
init(shape: TensorShape, scalars: [Scalar]) {
precondition(shape.contiguousSize == scalars.count,
"""
The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
provided.
""")
self = scalars.withUnsafeBufferPointer { bufferPointer in
Tensor(shape: shape, scalars: bufferPointer)
}
Expand All @@ -239,11 +246,14 @@ public extension Tensor {
/// - Parameters:
/// - shape: The shape of the tensor.
/// - scalars: The scalar contents of the tensor.
/// - Precondition: The number of scalars must equal the product of the
/// dimensions of the shape.
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
@inlinable
init(shape: TensorShape, scalars: UnsafeBufferPointer<Scalar>) {
precondition(scalars.count == shape.contiguousSize)
precondition(shape.contiguousSize == scalars.count,
"""
The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
provided.
""")
let handle = TensorHandle<Scalar>(
shape: shape.dimensions,
scalarsInitializer: { address in
Expand All @@ -257,11 +267,14 @@ public extension Tensor {
/// - Parameters:
/// - shape: The shape of the tensor.
/// - scalars: The scalar contents of the tensor.
/// - Precondition: The number of scalars must equal the product of the
/// dimensions of the shape.
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
@inlinable
init<C: RandomAccessCollection>(shape: TensorShape, scalars: C) where C.Element == Scalar {
precondition(scalars.count == shape.contiguousSize)
precondition(shape.contiguousSize == scalars.count,
"""
The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
provided.
""")
let handle = TensorHandle<Scalar>(
shape: shape.dimensions,
scalarsInitializer: { addr in
Expand Down