Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 81fd321

Browse files
committed
Remove TensorProtocol.
1 parent 2b36637 commit 81fd321

File tree

3 files changed

+6
-43
lines changed

3 files changed

+6
-43
lines changed

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public protocol AnyTensor {
2929
/// The generic parameter `Scalar` describes the type of scalars in the tensor (such as `Int32`,
3030
/// `Float`, etc).
3131
@frozen
32-
public struct Tensor<Scalar: TensorFlowScalar>: TensorProtocol {
32+
public struct Tensor<Scalar: TensorFlowScalar> {
3333
/// The underlying `TensorHandle`.
3434
/// - Note: `handle` is public to allow user defined ops, but should not normally be used.
3535
public let handle: TensorHandle<Scalar>
@@ -348,18 +348,8 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
348348
/// during the conversion from an array literal to a `Tensor`, and is purely
349349
/// for implementation purposes.
350350
@frozen
351-
public struct _TensorElementLiteral<Scalar>: TensorProtocol where Scalar: TensorFlowScalar {
351+
public struct _TensorElementLiteral<Scalar> where Scalar: TensorFlowScalar {
352352
@usableFromInline let tensor: Tensor<Scalar>
353-
354-
@inlinable
355-
public var handle: TensorHandle<Scalar> {
356-
return tensor.handle
357-
}
358-
359-
@inlinable
360-
public init(handle: TensorHandle<Scalar>) {
361-
tensor = Tensor(handle: handle)
362-
}
363353
}
364354

365355
extension _TensorElementLiteral: ExpressibleByBooleanLiteral

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,21 +240,20 @@ extension _TensorElementLiteral: TensorGroup {
240240
return [Scalar.tensorFlowDataType]
241241
}
242242

243-
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
243+
public var _tensorHandles: [_AnyTensorHandle] { tensor._tensorHandles }
244244

245245
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
246-
address!.initialize(to: handle._cTensorHandle)
246+
tensor._unpackTensorHandles(into: address)
247247
}
248248

249249
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
250-
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
250+
tensor = Tensor(_owning: tensorHandles)
251251
}
252252

253253
public init<C: RandomAccessCollection>(
254254
_handles: C
255255
) where C.Element: _AnyTensorHandle {
256-
precondition(_handles.count == 1)
257-
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
256+
tensor = Tensor(_handles: _handles)
258257
}
259258
}
260259

Sources/TensorFlow/Core/TensorProtocol.swift

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)