Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Make Element requirement in TensorArrayProtocol.init a conformance. #171

Merged
merged 3 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions Sources/TensorFlow/Core/TensorGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public protocol TensorArrayProtocol {
var _tensorHandles: [_AnyTensorHandle] { get }

init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle
init<C: RandomAccessCollection>(_handles: C) where C.Element: _AnyTensorHandle
}

/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
Expand Down Expand Up @@ -102,7 +102,7 @@ extension TensorHandle: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: _handles[_handles.startIndex])
}
Expand Down Expand Up @@ -131,7 +131,7 @@ extension ResourceHandle: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: _handles[_handles.startIndex])
}
Expand Down Expand Up @@ -160,7 +160,7 @@ extension VariantHandle: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: _handles[_handles.startIndex])
}
Expand Down Expand Up @@ -189,7 +189,7 @@ extension Tensor: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
}
Expand Down Expand Up @@ -218,7 +218,7 @@ extension _TensorElementLiteral: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
}
Expand Down Expand Up @@ -247,7 +247,7 @@ extension StringTensor: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 1)
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
}
Expand Down Expand Up @@ -290,7 +290,8 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {

let size = _handles.count / Int(Element._tensorHandleCount)
self = (0..<size).map {
let start = _handles.index(
Expand Down
2 changes: 1 addition & 1 deletion Sources/TensorFlow/Operators/Dataset.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
let firstStart = _handles.startIndex
let firstEnd = _handles.index(
firstStart, offsetBy: Int(T._tensorHandleCount))
Expand Down
2 changes: 1 addition & 1 deletion Tests/TensorFlowTests/OperatorTests/DatasetTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct SimpleOutput: TensorGroup {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 2)
let aIndex = _handles.startIndex
let bIndex = _handles.index(aIndex, offsetBy: 1)
Expand Down
14 changes: 7 additions & 7 deletions Tests/TensorFlowTests/TensorGroupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct Empty : TensorGroup {
init() {}
public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {}
) where C.Element: _AnyTensorHandle {}
public var _tensorHandles: [_AnyTensorHandle] { [] }
}

Expand All @@ -40,7 +40,7 @@ struct Simple : TensorGroup, Equatable {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 2)
let wIndex = _handles.startIndex
let bIndex = _handles.index(wIndex, offsetBy: 1)
Expand All @@ -64,7 +64,7 @@ struct Mixed : TensorGroup, Equatable {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
precondition(_handles.count == 2)
let floatIndex = _handles.startIndex
let intIndex = _handles.index(floatIndex, offsetBy: 1)
Expand Down Expand Up @@ -92,7 +92,7 @@ struct Nested : TensorGroup, Equatable {

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
let simpleStart = _handles.startIndex
let simpleEnd = _handles.index(
simpleStart, offsetBy: Int(Simple._tensorHandleCount))
Expand All @@ -116,7 +116,7 @@ struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorG

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
let tStart = _handles.startIndex
let tEnd = _handles.index(tStart, offsetBy: Int(T._tensorHandleCount))
t = T.init(_handles: _handles[tStart..<tEnd])
Expand All @@ -140,7 +140,7 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>

public init<C: RandomAccessCollection>(
_handles: C
) where C.Element == _AnyTensorHandle {
) where C.Element: _AnyTensorHandle {
let firstStart = _handles.startIndex
let firstEnd = _handles.index(
firstStart, offsetBy: Int(Generic<T,V>._tensorHandleCount))
Expand All @@ -153,7 +153,7 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>
}
}

func copy<T>(of handle: TensorHandle<T>) -> _AnyTensorHandle {
func copy<T>(of handle: TensorHandle<T>) -> TFETensorHandle {
let status = TF_NewStatus()
let result = TFETensorHandle(_owning: TFE_TensorHandleCopySharingTensor(
handle._cTensorHandle, status)!)
Expand Down