Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add two new requirements to TensorArrayProtocol #165

Merged
merged 6 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions Sources/TensorFlow/Core/TensorGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ public protocol TensorArrayProtocol {

var _tensorHandleCount: Int32 { get }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be changed to Int now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.

var _typeList: [TensorDataType] { get }
var _tensorHandles: [_AnyTensorHandle] { get }

init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle
}

/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
Expand Down Expand Up @@ -88,13 +90,21 @@ extension TensorHandle: TensorGroup {
return [Scalar.tensorFlowDataType]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(_owning: tensorHandles!.pointee)
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

precondition(_handles.count == 1)
self.init(handle: _handles[_handles.startIndex])
}
}

extension ResourceHandle: TensorGroup {
Expand All @@ -108,13 +118,21 @@ extension ResourceHandle: TensorGroup {
return [TensorDataType(TF_RESOURCE)]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(owning: tensorHandles!.pointee)
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

precondition(_handles.count == 1)
self.init(handle: _handles[_handles.startIndex])
}
}

extension VariantHandle: TensorGroup {
Expand All @@ -128,13 +146,21 @@ extension VariantHandle: TensorGroup {
return [TensorDataType(TF_VARIANT)]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(owning: tensorHandles!.pointee)
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: _handles[_handles.startIndex])
}
}

extension Tensor: TensorGroup {
Expand All @@ -152,9 +178,17 @@ extension Tensor: TensorGroup {
address!.initialize(to: handle._cTensorHandle)
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

precondition(_handles.count == 1)
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
}
}

extension _TensorElementLiteral: TensorGroup {
Expand All @@ -168,13 +202,21 @@ extension _TensorElementLiteral: TensorGroup {
return [Scalar.tensorFlowDataType]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: handle._cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

precondition(_handles.count == 1)
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
}
}

extension StringTensor: TensorGroup {
Expand All @@ -192,9 +234,17 @@ extension StringTensor: TensorGroup {
address!.initialize(to: handle._cTensorHandle)
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

precondition(_handles.count == 1)
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
}
}

extension Array: TensorArrayProtocol where Element: TensorGroup {
Expand All @@ -216,10 +266,31 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
count: Int(count)).joined())
}

public var _tensorHandles: ([_AnyTensorHandle]) {
var result: [_AnyTensorHandle] = []
result.reserveCapacity(Int(self._tensorHandleCount))
for elem in self {
result += elem._tensorHandles
}
return result
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
let size = count / Int(Element._tensorHandleCount)
self = Array((0..<size).map { Element.init(
_owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount)))
})
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

let size = _handles.count / Int(Element._tensorHandleCount)
self = (0..<size).map {
let start = _handles.index(
_handles.startIndex, offsetBy: $0 * Int(Element._tensorHandleCount))
let end = _handles.index(
start, offsetBy: Int(Element._tensorHandleCount))
return Element.init(_handles: _handles[start..<end])
}
}
}
16 changes: 15 additions & 1 deletion Sources/TensorFlow/Core/TensorHandle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
self.handle = TFETensorHandle(_owning: cTensorHandle)
}

public init(handle: _AnyTensorHandle) {
self.handle = handle
}

@usableFromInline
init(copyingFromCTensor cTensor: CTensor) {
let status = TF_NewStatus()
Expand Down Expand Up @@ -105,7 +109,7 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
extension TensorHandle where Scalar: TensorFlowScalar {
/// Create a `TensorHandle` with a closure that initializes the underlying buffer.
///
/// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a
/// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a
/// tensor with shape `shape`. `scalarsInitializer` must initialize the entire buffer, with
/// contiguous scalars in row-major order.
@inlinable
Expand Down Expand Up @@ -145,6 +149,11 @@ public struct ResourceHandle {
init(owning cTensorHandle: CTensorHandle) {
self.handle = TFETensorHandle(_owning: cTensorHandle)
}

@usableFromInline
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop @usableFromInline?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

init(handle: _AnyTensorHandle) {
self.handle = handle
}
}

public struct VariantHandle {
Expand All @@ -157,4 +166,9 @@ public struct VariantHandle {
init(owning cTensorHandle: CTensorHandle) {
self.handle = TFETensorHandle(_owning: cTensorHandle)
}

@usableFromInline
init(handle: _AnyTensorHandle) {
self.handle = handle
}
}
13 changes: 13 additions & 0 deletions Sources/TensorFlow/Operators/Dataset.swift
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,19 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
self.first = first
self.second = second
}

public var _tensorHandles: [_AnyTensorHandle] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public var _tensorHandles: [_AnyTensorHandle] {
public var _tensorHandles: [_AnyTensorHandle] {

first._tensorHandles + second._tensorHandles
}

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

let firstStart = _handles.startIndex
let firstEnd = _handles.index(
firstStart, offsetBy: Int(T._tensorHandleCount))
self.first = T.init(_handles: _handles[firstStart..<firstEnd])
self.second = U.init(_handles: _handles[firstEnd..<_handles.endIndex])
}
}

// TODO(SR-9156): This does not work in graph mode.
Expand Down
11 changes: 11 additions & 0 deletions Tests/TensorFlowTests/OperatorTests/DatasetTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ import XCTest
struct SimpleOutput: TensorGroup {
let a: TensorHandle<Int32>
let b: TensorHandle<Int32>

public init<C: RandomAccessCollection>(
_handles: C) where C.Element == _AnyTensorHandle {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_handles: C) where C.Element == _AnyTensorHandle {
_handles: C
) where C.Element == _AnyTensorHandle {

precondition(_handles.count == 2)
let aIndex = _handles.startIndex
let bIndex = _handles.index(aIndex, offsetBy: 1)
a = TensorHandle<Int32>(handle: _handles[aIndex])
b = TensorHandle<Int32>(handle: _handles[bIndex])
}

public var _tensorHandles: [_AnyTensorHandle] { [a.handle, b.handle] }
}

final class DatasetTests: XCTestCase {
Expand Down
Loading