-
Notifications
You must be signed in to change notification settings - Fork 137
Add two new requirements to TensorArrayProtocol #165
Changes from all commits
99b8d03
2eeb847
96b0f5a
e0423fe
5a4d789
75c98e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol { | |||||||
|
||||||||
var _tensorHandleCount: Int32 { get } | ||||||||
var _typeList: [TensorDataType] { get } | ||||||||
var _tensorHandles: [_AnyTensorHandle] { get } | ||||||||
|
||||||||
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) | ||||||||
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle | ||||||||
} | ||||||||
|
||||||||
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`. | ||||||||
|
@@ -88,13 +90,21 @@ extension TensorHandle: TensorGroup { | |||||||
return [Scalar.tensorFlowDataType] | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } | ||||||||
|
||||||||
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
address!.initialize(to: _cTensorHandle) | ||||||||
} | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
self.init(_owning: tensorHandles!.pointee) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
precondition(_handles.count == 1) | ||||||||
self.init(handle: _handles[_handles.startIndex]) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
extension ResourceHandle: TensorGroup { | ||||||||
|
@@ -108,13 +118,21 @@ extension ResourceHandle: TensorGroup { | |||||||
return [TensorDataType(TF_RESOURCE)] | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } | ||||||||
|
||||||||
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
address!.initialize(to: _cTensorHandle) | ||||||||
} | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
self.init(owning: tensorHandles!.pointee) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
precondition(_handles.count == 1) | ||||||||
self.init(handle: _handles[_handles.startIndex]) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
extension VariantHandle: TensorGroup { | ||||||||
|
@@ -128,13 +146,21 @@ extension VariantHandle: TensorGroup { | |||||||
return [TensorDataType(TF_VARIANT)] | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } | ||||||||
|
||||||||
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
address!.initialize(to: _cTensorHandle) | ||||||||
} | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
self.init(owning: tensorHandles!.pointee) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
precondition(_handles.count == 1) | ||||||||
self.init(handle: _handles[_handles.startIndex]) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
extension Tensor: TensorGroup { | ||||||||
|
@@ -152,9 +178,17 @@ extension Tensor: TensorGroup { | |||||||
address!.initialize(to: handle._cTensorHandle) | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
precondition(_handles.count == 1) | ||||||||
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
extension _TensorElementLiteral: TensorGroup { | ||||||||
|
@@ -168,13 +202,21 @@ extension _TensorElementLiteral: TensorGroup { | |||||||
return [Scalar.tensorFlowDataType] | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } | ||||||||
|
||||||||
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
address!.initialize(to: handle._cTensorHandle) | ||||||||
} | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
precondition(_handles.count == 1) | ||||||||
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
extension StringTensor: TensorGroup { | ||||||||
|
@@ -192,9 +234,17 @@ extension StringTensor: TensorGroup { | |||||||
address!.initialize(to: handle._cTensorHandle) | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
precondition(_handles.count == 1) | ||||||||
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
extension Array: TensorArrayProtocol where Element: TensorGroup { | ||||||||
|
@@ -216,10 +266,31 @@ extension Array: TensorArrayProtocol where Element: TensorGroup { | |||||||
count: Int(count)).joined()) | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: ([_AnyTensorHandle]) { | ||||||||
var result: [_AnyTensorHandle] = [] | ||||||||
result.reserveCapacity(Int(self._tensorHandleCount)) | ||||||||
for elem in self { | ||||||||
result += elem._tensorHandles | ||||||||
} | ||||||||
return result | ||||||||
} | ||||||||
|
||||||||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) { | ||||||||
let size = count / Int(Element._tensorHandleCount) | ||||||||
self = Array((0..<size).map { Element.init( | ||||||||
_owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount))) | ||||||||
}) | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
let size = _handles.count / Int(Element._tensorHandleCount) | ||||||||
self = (0..<size).map { | ||||||||
let start = _handles.index( | ||||||||
_handles.startIndex, offsetBy: $0 * Int(Element._tensorHandleCount)) | ||||||||
let end = _handles.index( | ||||||||
start, offsetBy: Int(Element._tensorHandleCount)) | ||||||||
return Element.init(_handles: _handles[start..<end]) | ||||||||
} | ||||||||
} | ||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible { | |
self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
} | ||
|
||
public init(handle: _AnyTensorHandle) { | ||
self.handle = handle | ||
} | ||
|
||
@usableFromInline | ||
init(copyingFromCTensor cTensor: CTensor) { | ||
let status = TF_NewStatus() | ||
|
@@ -105,7 +109,7 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible { | |
extension TensorHandle where Scalar: TensorFlowScalar { | ||
/// Create a `TensorHandle` with a closure that initializes the underlying buffer. | ||
/// | ||
/// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a | ||
/// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a | ||
/// tensor with shape `shape`. `scalarsInitializer` must initialize the entire buffer, with | ||
/// contiguous scalars in row-major order. | ||
@inlinable | ||
|
@@ -145,6 +149,11 @@ public struct ResourceHandle { | |
init(owning cTensorHandle: CTensorHandle) { | ||
self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
} | ||
|
||
@usableFromInline | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drop There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
init(handle: _AnyTensorHandle) { | ||
self.handle = handle | ||
} | ||
} | ||
|
||
public struct VariantHandle { | ||
|
@@ -157,4 +166,9 @@ public struct VariantHandle { | |
init(owning cTensorHandle: CTensorHandle) { | ||
self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
} | ||
|
||
@usableFromInline | ||
init(handle: _AnyTensorHandle) { | ||
self.handle = handle | ||
} | ||
} |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -215,6 +215,19 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup { | |||||||
self.first = first | ||||||||
self.second = second | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
first._tensorHandles + second._tensorHandles | ||||||||
} | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
let firstStart = _handles.startIndex | ||||||||
let firstEnd = _handles.index( | ||||||||
firstStart, offsetBy: Int(T._tensorHandleCount)) | ||||||||
self.first = T.init(_handles: _handles[firstStart..<firstEnd]) | ||||||||
self.second = U.init(_handles: _handles[firstEnd..<_handles.endIndex]) | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
// TODO(SR-9156): This does not work in graph mode. | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -18,6 +18,17 @@ import XCTest | |||||||
struct SimpleOutput: TensorGroup { | ||||||||
let a: TensorHandle<Int32> | ||||||||
let b: TensorHandle<Int32> | ||||||||
|
||||||||
public init<C: RandomAccessCollection>( | ||||||||
_handles: C) where C.Element == _AnyTensorHandle { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
precondition(_handles.count == 2) | ||||||||
let aIndex = _handles.startIndex | ||||||||
let bIndex = _handles.index(aIndex, offsetBy: 1) | ||||||||
a = TensorHandle<Int32>(handle: _handles[aIndex]) | ||||||||
b = TensorHandle<Int32>(handle: _handles[bIndex]) | ||||||||
} | ||||||||
|
||||||||
public var _tensorHandles: [_AnyTensorHandle] { [a.handle, b.handle] } | ||||||||
} | ||||||||
|
||||||||
final class DatasetTests: XCTestCase { | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be changed to
Int
now?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.