Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ba61cbd

Browse files
committed
Make Element requirement in TensorArrayProtocol.init a conformance.
1 parent f16b750 commit ba61cbd

File tree

4 files changed

+18
-17
lines changed

4 files changed

+18
-17
lines changed

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public protocol TensorArrayProtocol {
3535
var _tensorHandles: [_AnyTensorHandle] { get }
3636

3737
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
38-
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle
38+
init<C: RandomAccessCollection>(_handles: C) where C.Element : _AnyTensorHandle
3939
}
4040

4141
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
@@ -102,7 +102,7 @@ extension TensorHandle: TensorGroup {
102102

103103
public init<C: RandomAccessCollection>(
104104
_handles: C
105-
) where C.Element == _AnyTensorHandle {
105+
) where C.Element : _AnyTensorHandle {
106106
precondition(_handles.count == 1)
107107
self.init(handle: _handles[_handles.startIndex])
108108
}
@@ -131,7 +131,7 @@ extension ResourceHandle: TensorGroup {
131131

132132
public init<C: RandomAccessCollection>(
133133
_handles: C
134-
) where C.Element == _AnyTensorHandle {
134+
) where C.Element : _AnyTensorHandle {
135135
precondition(_handles.count == 1)
136136
self.init(handle: _handles[_handles.startIndex])
137137
}
@@ -160,7 +160,7 @@ extension VariantHandle: TensorGroup {
160160

161161
public init<C: RandomAccessCollection>(
162162
_handles: C
163-
) where C.Element == _AnyTensorHandle {
163+
) where C.Element : _AnyTensorHandle {
164164
precondition(_handles.count == 1)
165165
self.init(handle: _handles[_handles.startIndex])
166166
}
@@ -189,7 +189,7 @@ extension Tensor: TensorGroup {
189189

190190
public init<C: RandomAccessCollection>(
191191
_handles: C
192-
) where C.Element == _AnyTensorHandle {
192+
) where C.Element : _AnyTensorHandle {
193193
precondition(_handles.count == 1)
194194
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
195195
}
@@ -218,7 +218,7 @@ extension _TensorElementLiteral: TensorGroup {
218218

219219
public init<C: RandomAccessCollection>(
220220
_handles: C
221-
) where C.Element == _AnyTensorHandle {
221+
) where C.Element : _AnyTensorHandle {
222222
precondition(_handles.count == 1)
223223
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
224224
}
@@ -247,7 +247,7 @@ extension StringTensor: TensorGroup {
247247

248248
public init<C: RandomAccessCollection>(
249249
_handles: C
250-
) where C.Element == _AnyTensorHandle {
250+
) where C.Element : _AnyTensorHandle {
251251
precondition(_handles.count == 1)
252252
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
253253
}
@@ -290,7 +290,8 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
290290

291291
public init<C: RandomAccessCollection>(
292292
_handles: C
293-
) where C.Element == _AnyTensorHandle {
293+
) where C.Element : _AnyTensorHandle {
294+
294295
let size = _handles.count / Int(Element._tensorHandleCount)
295296
self = (0..<size).map {
296297
let start = _handles.index(

Sources/TensorFlow/Operators/Dataset.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
222222

223223
public init<C: RandomAccessCollection>(
224224
_handles: C
225-
) where C.Element == _AnyTensorHandle {
225+
) where C.Element : _AnyTensorHandle {
226226
let firstStart = _handles.startIndex
227227
let firstEnd = _handles.index(
228228
firstStart, offsetBy: Int(T._tensorHandleCount))

Tests/TensorFlowTests/OperatorTests/DatasetTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ struct SimpleOutput: TensorGroup {
2121

2222
public init<C: RandomAccessCollection>(
2323
_handles: C
24-
) where C.Element == _AnyTensorHandle {
24+
) where C.Element : _AnyTensorHandle {
2525
precondition(_handles.count == 2)
2626
let aIndex = _handles.startIndex
2727
let bIndex = _handles.index(aIndex, offsetBy: 1)

Tests/TensorFlowTests/TensorGroupTests.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct Empty : TensorGroup {
2626
init() {}
2727
public init<C: RandomAccessCollection>(
2828
_handles: C
29-
) where C.Element == _AnyTensorHandle {}
29+
) where C.Element : _AnyTensorHandle {}
3030
public var _tensorHandles: [_AnyTensorHandle] { [] }
3131
}
3232

@@ -40,7 +40,7 @@ struct Simple : TensorGroup, Equatable {
4040

4141
public init<C: RandomAccessCollection>(
4242
_handles: C
43-
) where C.Element == _AnyTensorHandle {
43+
) where C.Element : _AnyTensorHandle {
4444
precondition(_handles.count == 2)
4545
let wIndex = _handles.startIndex
4646
let bIndex = _handles.index(wIndex, offsetBy: 1)
@@ -64,7 +64,7 @@ struct Mixed : TensorGroup, Equatable {
6464

6565
public init<C: RandomAccessCollection>(
6666
_handles: C
67-
) where C.Element == _AnyTensorHandle {
67+
) where C.Element : _AnyTensorHandle {
6868
precondition(_handles.count == 2)
6969
let floatIndex = _handles.startIndex
7070
let intIndex = _handles.index(floatIndex, offsetBy: 1)
@@ -92,7 +92,7 @@ struct Nested : TensorGroup, Equatable {
9292

9393
public init<C: RandomAccessCollection>(
9494
_handles: C
95-
) where C.Element == _AnyTensorHandle {
95+
) where C.Element : _AnyTensorHandle {
9696
let simpleStart = _handles.startIndex
9797
let simpleEnd = _handles.index(
9898
simpleStart, offsetBy: Int(Simple._tensorHandleCount))
@@ -116,7 +116,7 @@ struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorG
116116

117117
public init<C: RandomAccessCollection>(
118118
_handles: C
119-
) where C.Element == _AnyTensorHandle {
119+
) where C.Element : _AnyTensorHandle {
120120
let tStart = _handles.startIndex
121121
let tEnd = _handles.index(tStart, offsetBy: Int(T._tensorHandleCount))
122122
t = T.init(_handles: _handles[tStart..<tEnd])
@@ -140,7 +140,7 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>
140140

141141
public init<C: RandomAccessCollection>(
142142
_handles: C
143-
) where C.Element == _AnyTensorHandle {
143+
) where C.Element : _AnyTensorHandle {
144144
let firstStart = _handles.startIndex
145145
let firstEnd = _handles.index(
146146
firstStart, offsetBy: Int(Generic<T,V>._tensorHandleCount))
@@ -153,7 +153,7 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>
153153
}
154154
}
155155

156-
func copy<T>(of handle: TensorHandle<T>) -> _AnyTensorHandle {
156+
func copy<T>(of handle: TensorHandle<T>) -> TFETensorHandle {
157157
let status = TF_NewStatus()
158158
let result = TFETensorHandle(_owning: TFE_TensorHandleCopySharingTensor(
159159
handle._cTensorHandle, status)!)

0 commit comments

Comments
 (0)