-
Notifications
You must be signed in to change notification settings - Fork 137
Add two new requirements to TensorArrayProtocol #165
Conversation
@@ -51,6 +53,8 @@ public protocol TensorGroup: TensorArrayProtocol { | |||
/// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors | |||
/// starting at address `tensorHandles`. | |||
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) | |||
|
|||
init(handles: [_AnyTensorHandle]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is redundant because TensorGroup
inherits from TensorArrayProtocol
.
init(handles: [_AnyTensorHandle]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
} | ||
|
||
struct Nested : TensorGroup, Equatable { | ||
// Immutable. | ||
let simple: Simple | ||
// Mutable. | ||
var mixed: Mixed | ||
|
||
init(simple: Simple, mixed: Mixed) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This initializer shouldn't be necessary. There's already an implicitly synthesized memberwise initializer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is what I thought, but I get the following error, when I don't have these:
[1/2] Compiling TensorFlowTests TensorGroupTests.swift
/usr/local/google/home/bgogul/workspace/brain/s4tf/tensorflow-swift-apis/Tests/TensorFlowTests/TensorGroupTests.swift:173:22: error: cannot invoke initializer for type 'Simple' with an argument list of type '(w: Tensor<Float>, b: Tensor<Float>)'
let simple = Simple(w: w, b: b)
^
/usr/local/google/home/bgogul/workspace/brain/s4tf/tensorflow-swift-apis/Tests/TensorFlowTests/TensorGroupTests.swift:173:22: note: overloads for 'Simple' exist with these partially matching parameter lists: (_handles: C), (_owning: Optional<UnsafePointer<OpaquePointer>>\
)
let simple = Simple(w: w, b: b)
^
...
Is it because of the the protocol initializers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. Yeah, that's probably why.
} | ||
|
||
struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorGroup, Equatable { | ||
var t: T | ||
var u: U | ||
|
||
public init(t: T, u: U) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for this memberwise initializer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above.
|
||
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) | ||
init(handles: [_AnyTensorHandle]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Requiring the argument to be an Array
is not ideal, because in a lot of cases we are creating a tensor group from a slice of an array. Converting a slice to an Array
creates an unnecessary copy. Instead, I think this should take a generic RandomAccessCollection
.
Also, for consistency with other requirements, it's better for the first argument label of this initializer to start with an underscore.
init(handles: [_AnyTensorHandle]) | |
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Done.
} | ||
} | ||
|
||
func copyOf<T>(handle: TensorHandle<T>) -> _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
func copyOf<T>(handle: TensorHandle<T>) -> _AnyTensorHandle { | |
func copy<T>(of handle: TensorHandle<T>) -> _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Random thoughts (since this is a test, naming doesn't matter at all!):
Swift naming conventions recommend that when a verb/noun phrase that the function represents ("copy of handle") contains a prepositional phrase formed with an argument ("of handle"), make the preposition be an argument label. Also, when an argument label merely repeats the type information of the argument, it ("handle") should be omitted.
What do you think of the following?
func copyOf<T>(handle: TensorHandle<T>) -> _AnyTensorHandle { | |
func copy<T>(of handle: TensorHandle<T>) -> _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol { | |||
|
|||
var _tensorHandleCount: Int32 { get } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be changed to Int
now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.
@@ -145,6 +149,11 @@ public struct ResourceHandle { | |||
init(owning cTensorHandle: CTensorHandle) { | |||
self.handle = TFETensorHandle(_owning: cTensorHandle) | |||
} | |||
|
|||
@usableFromInline |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop @usableFromInline
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
public init(handles: [_AnyTensorHandle]) { | ||
let size = handles.count / Int(Element._tensorHandleCount) | ||
self = Array((0..<size).map { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The map(_:)
method already produces an array. No need to call the array initializer.
self = Array((0..<size).map { | |
self = (0..<size).map { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
public init(handles: [_AnyTensorHandle]) { | ||
let size = handles.count / Int(Element._tensorHandleCount) | ||
self = Array((0..<size).map { | ||
let start = $0 * Int(Element._tensorHandleCount) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix indentation here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review, @rxwei. I addressed your comments. PTAL.
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol { | |||
|
|||
var _tensorHandleCount: Int32 { get } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.
|
||
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) | ||
init(handles: [_AnyTensorHandle]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Done.
@@ -51,6 +53,8 @@ public protocol TensorGroup: TensorArrayProtocol { | |||
/// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors | |||
/// starting at address `tensorHandles`. | |||
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) | |||
|
|||
init(handles: [_AnyTensorHandle]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
public init(handles: [_AnyTensorHandle]) { | ||
let size = handles.count / Int(Element._tensorHandleCount) | ||
self = Array((0..<size).map { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
public init(handles: [_AnyTensorHandle]) { | ||
let size = handles.count / Int(Element._tensorHandleCount) | ||
self = Array((0..<size).map { | ||
let start = $0 * Int(Element._tensorHandleCount) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -145,6 +149,11 @@ public struct ResourceHandle { | |||
init(owning cTensorHandle: CTensorHandle) { | |||
self.handle = TFETensorHandle(_owning: cTensorHandle) | |||
} | |||
|
|||
@usableFromInline |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
} | ||
|
||
struct Nested : TensorGroup, Equatable { | ||
// Immutable. | ||
let simple: Simple | ||
// Mutable. | ||
var mixed: Mixed | ||
|
||
init(simple: Simple, mixed: Mixed) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is what I thought, but I get the following error, when I don't have these:
[1/2] Compiling TensorFlowTests TensorGroupTests.swift
/usr/local/google/home/bgogul/workspace/brain/s4tf/tensorflow-swift-apis/Tests/TensorFlowTests/TensorGroupTests.swift:173:22: error: cannot invoke initializer for type 'Simple' with an argument list of type '(w: Tensor<Float>, b: Tensor<Float>)'
let simple = Simple(w: w, b: b)
^
/usr/local/google/home/bgogul/workspace/brain/s4tf/tensorflow-swift-apis/Tests/TensorFlowTests/TensorGroupTests.swift:173:22: note: overloads for 'Simple' exist with these partially matching parameter lists: (_handles: C), (_owning: Optional<UnsafePointer<OpaquePointer>>\
)
let simple = Simple(w: w, b: b)
^
...
Is it because of the the protocol initializers?
} | ||
|
||
struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorGroup, Equatable { | ||
var t: T | ||
var u: U | ||
|
||
public init(t: T, u: U) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above.
} | ||
} | ||
|
||
func copyOf<T>(handle: TensorHandle<T>) -> _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking great!
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: _cTensorHandle) | ||
} | ||
|
||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||
self.init(owning: tensorHandles!.pointee) | ||
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: _cTensorHandle) | ||
} | ||
|
||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||
self.init(_owning: tensorHandles!.pointee) | ||
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: handle._cTensorHandle) | ||
} | ||
|
||
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
} | ||
|
||
public init<C: RandomAccessCollection>( | ||
_handles: C) where C.Element == _AnyTensorHandle { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_handles: C) where C.Element == _AnyTensorHandle { | |
_handles: C | |
) where C.Element == _AnyTensorHandle { |
Oops, just noticed these messages. Here is a PR to fix the style issues: #170 |
`TensorArrayProtocol` and `TensorGroup` have been [moved](tensorflow/swift-apis#139) to [tensorflow/swift-apis](https://github.com/tensorflow/swift-apis), and their derived conformances tests have been [copied](tensorflow/swift-apis#158) there as well. [tensorflow/swift-apis#165](tensorflow/swift-apis#165) changed the protocol requirements which made `TensorGroup` conformances no longer entirely derivable, but the derived conformances tests in this repo have not been updated, and it's blocking us from updating the checkout hash for tensorflow/swift-apis. This PR removes `TensorGroup` derived conformance tests from apple/swift. We expect future changes to `TensorGroup` derived conformances to thoroughly test against tensorflow/swift-apis.
`TensorArrayProtocol` and `TensorGroup` have been [moved](tensorflow/swift-apis#139) to [tensorflow/swift-apis](https://github.com/tensorflow/swift-apis), and their derived conformances tests have been [copied](tensorflow/swift-apis#158) there as well. [tensorflow/swift-apis#165](tensorflow/swift-apis#165) changed the protocol requirements which made `TensorGroup` conformances no longer entirely derivable, but the derived conformances tests in this repo have not been updated, and it's blocking us from updating the checkout hash for tensorflow/swift-apis. This PR removes `TensorGroup` derived conformance tests from apple/swift. We expect future changes to `TensorGroup` derived conformances to thoroughly test against tensorflow/swift-apis. Unblocks #25235.
This is one of the first steps in addressing https://bugs.swift.org/browse/TF-542. Specifically, this PR introduces two new requirements to
TensorArrayProtocol
andTensorGroup
:init(handles: [_AnyTensorHandle])
var _tensorHandles: [_AnyTensorHandle] { get }
This is also necessary for the implementation of LazyTensor.
The Derived conformances for
TensorGroup
does not address these requirements yet. Therefore, the tests have been changed to conform withTensorGroup
explicitly.