Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 687f5cc

Browse files
committed
Make Element requirement in TensorArrayProtocol.init a conformance.
1 parent 738b7a5 commit 687f5cc

File tree

4 files changed

+24
-17
lines changed

4 files changed

+24
-17
lines changed

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public protocol TensorArrayProtocol {
3535
var _tensorHandles: [_AnyTensorHandle] { get }
3636

3737
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
38-
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle
38+
init<C: RandomAccessCollection>(_handles: C) where C.Element : _AnyTensorHandle
3939
}
4040

4141
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
@@ -101,7 +101,8 @@ extension TensorHandle: TensorGroup {
101101
}
102102

103103
public init<C: RandomAccessCollection>(
104-
_handles: C) where C.Element == _AnyTensorHandle {
104+
_handles: C
105+
) where C.Element : _AnyTensorHandle {
105106
precondition(_handles.count == 1)
106107
self.init(handle: _handles[_handles.startIndex])
107108
}
@@ -129,7 +130,8 @@ extension ResourceHandle: TensorGroup {
129130
}
130131

131132
public init<C: RandomAccessCollection>(
132-
_handles: C) where C.Element == _AnyTensorHandle {
133+
_handles: C
134+
) where C.Element : _AnyTensorHandle {
133135
precondition(_handles.count == 1)
134136
self.init(handle: _handles[_handles.startIndex])
135137
}
@@ -157,7 +159,8 @@ extension VariantHandle: TensorGroup {
157159
}
158160

159161
public init<C: RandomAccessCollection>(
160-
_handles: C) where C.Element == _AnyTensorHandle {
162+
_handles: C
163+
) where C.Element : _AnyTensorHandle {
161164
precondition(_handles.count == 1)
162165
self.init(handle: _handles[_handles.startIndex])
163166
}
@@ -185,7 +188,8 @@ extension Tensor: TensorGroup {
185188
}
186189

187190
public init<C: RandomAccessCollection>(
188-
_handles: C) where C.Element == _AnyTensorHandle {
191+
_handles: C
192+
) where C.Element : _AnyTensorHandle {
189193
precondition(_handles.count == 1)
190194
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
191195
}
@@ -213,7 +217,8 @@ extension _TensorElementLiteral: TensorGroup {
213217
}
214218

215219
public init<C: RandomAccessCollection>(
216-
_handles: C) where C.Element == _AnyTensorHandle {
220+
_handles: C
221+
) where C.Element : _AnyTensorHandle {
217222
precondition(_handles.count == 1)
218223
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
219224
}
@@ -241,7 +246,8 @@ extension StringTensor: TensorGroup {
241246
}
242247

243248
public init<C: RandomAccessCollection>(
244-
_handles: C) where C.Element == _AnyTensorHandle {
249+
_handles: C
250+
) where C.Element : _AnyTensorHandle {
245251
precondition(_handles.count == 1)
246252
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
247253
}
@@ -283,7 +289,8 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
283289
}
284290

285291
public init<C: RandomAccessCollection>(
286-
_handles: C) where C.Element == _AnyTensorHandle {
292+
_handles: C
293+
) where C.Element : _AnyTensorHandle {
287294
let size = _handles.count / Int(Element._tensorHandleCount)
288295
self = (0..<size).map {
289296
let start = _handles.index(

Sources/TensorFlow/Operators/Dataset.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
221221
}
222222

223223
public init<C: RandomAccessCollection>(
224-
_handles: C) where C.Element == _AnyTensorHandle {
224+
_handles: C) where C.Element : _AnyTensorHandle {
225225
let firstStart = _handles.startIndex
226226
let firstEnd = _handles.index(
227227
firstStart, offsetBy: Int(T._tensorHandleCount))

Tests/TensorFlowTests/OperatorTests/DatasetTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct SimpleOutput: TensorGroup {
2020
let b: TensorHandle<Int32>
2121

2222
public init<C: RandomAccessCollection>(
23-
_handles: C) where C.Element == _AnyTensorHandle {
23+
_handles: C) where C.Element : _AnyTensorHandle {
2424
precondition(_handles.count == 2)
2525
let aIndex = _handles.startIndex
2626
let bIndex = _handles.index(aIndex, offsetBy: 1)

Tests/TensorFlowTests/TensorGroupTests.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ extension TensorDataType : Equatable {
2525
struct Empty : TensorGroup {
2626
init() {}
2727
public init<C: RandomAccessCollection>(
28-
_handles: C) where C.Element == _AnyTensorHandle {}
28+
_handles: C) where C.Element : _AnyTensorHandle {}
2929
public var _tensorHandles: [_AnyTensorHandle] { [] }
3030
}
3131

@@ -38,7 +38,7 @@ struct Simple : TensorGroup, Equatable {
3838
}
3939

4040
public init<C: RandomAccessCollection>(
41-
_handles: C) where C.Element == _AnyTensorHandle {
41+
_handles: C) where C.Element : _AnyTensorHandle {
4242
precondition(_handles.count == 2)
4343
let wIndex = _handles.startIndex
4444
let bIndex = _handles.index(wIndex, offsetBy: 1)
@@ -61,7 +61,7 @@ struct Mixed : TensorGroup, Equatable {
6161
}
6262

6363
public init<C: RandomAccessCollection>(
64-
_handles: C) where C.Element == _AnyTensorHandle {
64+
_handles: C) where C.Element : _AnyTensorHandle {
6565
precondition(_handles.count == 2)
6666
let floatIndex = _handles.startIndex
6767
let intIndex = _handles.index(floatIndex, offsetBy: 1)
@@ -88,7 +88,7 @@ struct Nested : TensorGroup, Equatable {
8888
}
8989

9090
public init<C: RandomAccessCollection>(
91-
_handles: C) where C.Element == _AnyTensorHandle {
91+
_handles: C) where C.Element : _AnyTensorHandle {
9292
let simpleStart = _handles.startIndex
9393
let simpleEnd = _handles.index(
9494
simpleStart, offsetBy: Int(Simple._tensorHandleCount))
@@ -111,7 +111,7 @@ struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorG
111111
}
112112

113113
public init<C: RandomAccessCollection>(
114-
_handles: C) where C.Element == _AnyTensorHandle {
114+
_handles: C) where C.Element : _AnyTensorHandle {
115115
let tStart = _handles.startIndex
116116
let tEnd = _handles.index(tStart, offsetBy: Int(T._tensorHandleCount))
117117
t = T.init(_handles: _handles[tStart..<tEnd])
@@ -134,7 +134,7 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>
134134
}
135135

136136
public init<C: RandomAccessCollection>(
137-
_handles: C) where C.Element == _AnyTensorHandle {
137+
_handles: C) where C.Element : _AnyTensorHandle {
138138
let firstStart = _handles.startIndex
139139
let firstEnd = _handles.index(
140140
firstStart, offsetBy: Int(Generic<T,V>._tensorHandleCount))
@@ -147,7 +147,7 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>
147147
}
148148
}
149149

150-
func copy<T>(of handle: TensorHandle<T>) -> _AnyTensorHandle {
150+
func copy<T>(of handle: TensorHandle<T>) -> TFETensorHandle {
151151
let status = TF_NewStatus()
152152
let result = TFETensorHandle(_owning: TFE_TensorHandleCopySharingTensor(
153153
handle._cTensorHandle, status)!)

0 commit comments

Comments
 (0)