Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 738b7a5

Browse files
authored
Add two new requirements to TensorGroup and TensorArrayProtocol (#165)
1 parent d7eff12 commit 738b7a5

File tree

5 files changed

+250
-92
lines changed

5 files changed

+250
-92
lines changed

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol {
3232

3333
var _tensorHandleCount: Int32 { get }
3434
var _typeList: [TensorDataType] { get }
35+
var _tensorHandles: [_AnyTensorHandle] { get }
3536

3637
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
38+
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle
3739
}
3840

3941
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
@@ -88,13 +90,21 @@ extension TensorHandle: TensorGroup {
8890
return [Scalar.tensorFlowDataType]
8991
}
9092

93+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }
94+
9195
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
9296
address!.initialize(to: _cTensorHandle)
9397
}
9498

9599
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
96100
self.init(_owning: tensorHandles!.pointee)
97101
}
102+
103+
public init<C: RandomAccessCollection>(
104+
_handles: C) where C.Element == _AnyTensorHandle {
105+
precondition(_handles.count == 1)
106+
self.init(handle: _handles[_handles.startIndex])
107+
}
98108
}
99109

100110
extension ResourceHandle: TensorGroup {
@@ -108,13 +118,21 @@ extension ResourceHandle: TensorGroup {
108118
return [TensorDataType(TF_RESOURCE)]
109119
}
110120

121+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }
122+
111123
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
112124
address!.initialize(to: _cTensorHandle)
113125
}
114126

115127
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
116128
self.init(owning: tensorHandles!.pointee)
117129
}
130+
131+
public init<C: RandomAccessCollection>(
132+
_handles: C) where C.Element == _AnyTensorHandle {
133+
precondition(_handles.count == 1)
134+
self.init(handle: _handles[_handles.startIndex])
135+
}
118136
}
119137

120138
extension VariantHandle: TensorGroup {
@@ -128,13 +146,21 @@ extension VariantHandle: TensorGroup {
128146
return [TensorDataType(TF_VARIANT)]
129147
}
130148

149+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }
150+
131151
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
132152
address!.initialize(to: _cTensorHandle)
133153
}
134154

135155
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
136156
self.init(owning: tensorHandles!.pointee)
137157
}
158+
159+
public init<C: RandomAccessCollection>(
160+
_handles: C) where C.Element == _AnyTensorHandle {
161+
precondition(_handles.count == 1)
162+
self.init(handle: _handles[_handles.startIndex])
163+
}
138164
}
139165

140166
extension Tensor: TensorGroup {
@@ -152,9 +178,17 @@ extension Tensor: TensorGroup {
152178
address!.initialize(to: handle._cTensorHandle)
153179
}
154180

181+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
182+
155183
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
156184
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
157185
}
186+
187+
public init<C: RandomAccessCollection>(
188+
_handles: C) where C.Element == _AnyTensorHandle {
189+
precondition(_handles.count == 1)
190+
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
191+
}
158192
}
159193

160194
extension _TensorElementLiteral: TensorGroup {
@@ -168,13 +202,21 @@ extension _TensorElementLiteral: TensorGroup {
168202
return [Scalar.tensorFlowDataType]
169203
}
170204

205+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
206+
171207
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
172208
address!.initialize(to: handle._cTensorHandle)
173209
}
174210

175211
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
176212
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
177213
}
214+
215+
public init<C: RandomAccessCollection>(
216+
_handles: C) where C.Element == _AnyTensorHandle {
217+
precondition(_handles.count == 1)
218+
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
219+
}
178220
}
179221

180222
extension StringTensor: TensorGroup {
@@ -192,9 +234,17 @@ extension StringTensor: TensorGroup {
192234
address!.initialize(to: handle._cTensorHandle)
193235
}
194236

237+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
238+
195239
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
196240
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
197241
}
242+
243+
public init<C: RandomAccessCollection>(
244+
_handles: C) where C.Element == _AnyTensorHandle {
245+
precondition(_handles.count == 1)
246+
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
247+
}
198248
}
199249

200250
extension Array: TensorArrayProtocol where Element: TensorGroup {
@@ -216,10 +266,31 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
216266
count: Int(count)).joined())
217267
}
218268

269+
public var _tensorHandles: ([_AnyTensorHandle]) {
270+
var result: [_AnyTensorHandle] = []
271+
result.reserveCapacity(Int(self._tensorHandleCount))
272+
for elem in self {
273+
result += elem._tensorHandles
274+
}
275+
return result
276+
}
277+
219278
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
220279
let size = count / Int(Element._tensorHandleCount)
221280
self = Array((0..<size).map { Element.init(
222281
_owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount)))
223282
})
224283
}
284+
285+
public init<C: RandomAccessCollection>(
286+
_handles: C) where C.Element == _AnyTensorHandle {
287+
let size = _handles.count / Int(Element._tensorHandleCount)
288+
self = (0..<size).map {
289+
let start = _handles.index(
290+
_handles.startIndex, offsetBy: $0 * Int(Element._tensorHandleCount))
291+
let end = _handles.index(
292+
start, offsetBy: Int(Element._tensorHandleCount))
293+
return Element.init(_handles: _handles[start..<end])
294+
}
295+
}
225296
}

Sources/TensorFlow/Core/TensorHandle.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
6262
self.handle = TFETensorHandle(_owning: cTensorHandle)
6363
}
6464

65+
public init(handle: _AnyTensorHandle) {
66+
self.handle = handle
67+
}
68+
6569
@usableFromInline
6670
init(copyingFromCTensor cTensor: CTensor) {
6771
let status = TF_NewStatus()
@@ -145,6 +149,11 @@ public struct ResourceHandle {
145149
init(owning cTensorHandle: CTensorHandle) {
146150
self.handle = TFETensorHandle(_owning: cTensorHandle)
147151
}
152+
153+
@usableFromInline
154+
init(handle: _AnyTensorHandle) {
155+
self.handle = handle
156+
}
148157
}
149158

150159
public struct VariantHandle {
@@ -157,4 +166,9 @@ public struct VariantHandle {
157166
init(owning cTensorHandle: CTensorHandle) {
158167
self.handle = TFETensorHandle(_owning: cTensorHandle)
159168
}
169+
170+
@usableFromInline
171+
init(handle: _AnyTensorHandle) {
172+
self.handle = handle
173+
}
160174
}

Sources/TensorFlow/Operators/Dataset.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,19 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
215215
self.first = first
216216
self.second = second
217217
}
218+
219+
public var _tensorHandles: [_AnyTensorHandle] {
220+
first._tensorHandles + second._tensorHandles
221+
}
222+
223+
public init<C: RandomAccessCollection>(
224+
_handles: C) where C.Element == _AnyTensorHandle {
225+
let firstStart = _handles.startIndex
226+
let firstEnd = _handles.index(
227+
firstStart, offsetBy: Int(T._tensorHandleCount))
228+
self.first = T.init(_handles: _handles[firstStart..<firstEnd])
229+
self.second = U.init(_handles: _handles[firstEnd..<_handles.endIndex])
230+
}
218231
}
219232

220233
// TODO(SR-9156): This does not work in graph mode.

Tests/TensorFlowTests/OperatorTests/DatasetTests.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ import XCTest
1818
struct SimpleOutput: TensorGroup {
1919
let a: TensorHandle<Int32>
2020
let b: TensorHandle<Int32>
21+
22+
public init<C: RandomAccessCollection>(
23+
_handles: C) where C.Element == _AnyTensorHandle {
24+
precondition(_handles.count == 2)
25+
let aIndex = _handles.startIndex
26+
let bIndex = _handles.index(aIndex, offsetBy: 1)
27+
a = TensorHandle<Int32>(handle: _handles[aIndex])
28+
b = TensorHandle<Int32>(handle: _handles[bIndex])
29+
}
30+
31+
public var _tensorHandles: [_AnyTensorHandle] { [a.handle, b.handle] }
2132
}
2233

2334
final class DatasetTests: XCTestCase {

0 commit comments

Comments
 (0)