[DynamicCompilation] Add _AnyTensorHandle and Swift-C round trip entry points #19529
@swift-ci please test tensorflow
@swift-ci please test tensorflow Linux GPU
Nice patch! Left some comments.
@@ -44,6 +42,20 @@ public final class TensorHandle<Scalar : AccelerableByTensorFlow> {
    self.cTensorHandle = cTensorHandle
  }

  deinit {
    debugLog("De-initializing TensorHandle.")
TensorHandle -> _AnyTensorHandle?
Actually I'm going to move initializers back into TensorHandle<T> since only that would be correct, and _AnyTensorHandle is just a type-erased class for easier IRGen. So it's always a TensorHandle.
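To illustrate the design described in this comment, here is a minimal sketch of a scalar-type-erased base class with a generic subclass. It is not the code from this patch: `FakeCTensorHandle`, `AnyHandleSketch`, `HandleSketch`, and `extractCHandle` are hypothetical names, and a plain `Int` stands in for the opaque C handle.

```swift
// Hypothetical stand-in for the opaque C tensor handle pointer.
typealias FakeCTensorHandle = Int

// Scalar-type-erased base class: it stores the C handle and knows nothing
// about the scalar type.
class AnyHandleSketch {
    let cHandle: FakeCTensorHandle
    init(cHandle: FakeCTensorHandle) { self.cHandle = cHandle }
}

// Generic subclass: the public-facing `owning:` initializer lives here, so
// client code always constructs a concrete, typed handle.
final class HandleSketch<Scalar>: AnyHandleSketch {
    init(owning cHandle: FakeCTensorHandle) {
        super.init(cHandle: cHandle)
    }
}

// A non-generic entry point can accept the erased type, so a caller does
// not need to know `Scalar` to reach the C handle.
func extractCHandle(_ handle: AnyHandleSketch) -> FakeCTensorHandle {
    return handle.cHandle
}

let floatHandle = HandleSketch<Float>(owning: 7)
let erased: AnyHandleSketch = floatHandle
```

In this sketch the erased reference still points at a `HandleSketch<Float>` instance, which mirrors the remark that an `_AnyTensorHandle` is always a `TensorHandle`.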
cHandle = TFE_TensorHandleCopySharingTensor(cHandle, status)
expectEqual(TF_GetCode(status), TF_OK)
TF_DeleteStatus(status)
let anyHandle = _CreateTensorHandleFromC(cHandle, TF_FLOAT)
Instead of passing in TF_FLOAT, can we compute that via TFE_TensorHandleDataType() within TFE_TensorHandleCopySharingTensor()? This way, within IRGen, we won't need to remember the dtype for each tensor handle. (Otherwise we may want to return that dtype in _GetCTensorHandleFromSwift(), which complicates things and is not necessary for IRGen.)
Done
@@ -99,6 +99,11 @@ static bool isUserIgnoredByPartitioning(SILInstruction *inst) {
/// unexpected.
static CanType getSingleElementDeclFieldType(NominalTypeDecl *decl) {
  auto *field = tf::getFieldIfContainsSingleField(decl);
  if (!field)
Please extend the API comment to describe/cover this new case.
let status = TF_NewStatus()
// We must do a copy, i.e. a retain on the tensor handle, to make sure it won't
// get double-free'd when both `zero` and `anyHandle` below go out of scope.
cHandle = TFE_TensorHandleCopySharingTensor(cHandle, status)
Hmm, this suggests using TensorHandle<...>(owning: cHandle) within _GetCTensorHandleFromSwift() is not correct. A C handle should not be owned by two or more TensorHandle objects.
In IRGen, for a given tensor handle x, we may need to call _GetCTensorHandleFromSwift() on it multiple times, if x is consumed by multiple tfops. We don't want IRGen code to call TFE_TensorHandleCopySharingTensor to rebalance the ref count.
Can we call TFE_TensorHandleCopySharingTensor() within _GetCTensorHandleFromSwift()?
Will the corresponding TF call emitted by IRGen consume the retain count?
As we discussed in person, making _GetCTensorHandleFromSwift not make a copy is the correct behavior, since the lifetime of each graph_op argument is guaranteed by construction.
This test case is irrelevant because it's making a new TensorHandle own the extracted CTensorHandle, which is for testing purposes only and will not be emitted by IRGen.
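The ownership question in this thread can be modeled with a small self-contained sketch. Everything here is hypothetical and only mimics the shape of the real API: `FakeCHandle` is a manually ref-counted stand-in for a C tensor handle, `copySharing` plays the role of `TFE_TensorHandleCopySharingTensor` (a retain), and `getCHandle` plays the role of a borrowing `_GetCTensorHandleFromSwift` that does not retain.

```swift
// Hypothetical ref-counted stand-in for a C tensor handle.
final class FakeCHandle {
    var refCount = 1
}

// Models a retain, in the spirit of TFE_TensorHandleCopySharingTensor.
func copySharing(_ handle: FakeCHandle) -> FakeCHandle {
    handle.refCount += 1
    return handle
}

// Models a Swift-side owner that releases the C handle when it goes away.
final class OwningHandle {
    let cHandle: FakeCHandle
    init(owning cHandle: FakeCHandle) { self.cHandle = cHandle }
    deinit { cHandle.refCount -= 1 }
}

// Borrowing accessor: hands out the C handle without retaining it. This is
// safe as long as the owner outlives every use of the borrowed handle.
func getCHandle(_ handle: OwningHandle) -> FakeCHandle {
    return handle.cHandle
}

func runOwnershipDemo() -> (balanced: Int, unbalanced: Int) {
    // Case 1: a second owner is created from a retained copy, as the test
    // in this patch does. Retains and releases balance out.
    let balancedRaw = FakeCHandle()
    do {
        let first = OwningHandle(owning: balancedRaw)
        _ = OwningHandle(owning: copySharing(getCHandle(first)))
    }
    // Case 2: a second owner is created from the borrowed handle with no
    // retain. The count drops below zero, which models a double free.
    let unbalancedRaw = FakeCHandle()
    do {
        let first = OwningHandle(owning: unbalancedRaw)
        _ = OwningHandle(owning: getCHandle(first))
    }
    return (balancedRaw.refCount, unbalancedRaw.refCount)
}
```

Under these assumptions, the copy is only needed when a second owner is created, as in the test. Plain borrowing for the duration of a call needs no extra retain.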
Awesome, thank you for driving this Richard!
@swift-ci please test tensorflow
Introduce a scalar-type-erased _AnyTensorHandle as a base class for TensorHandle<T>, along with two runtime entry points that assist graph_op lowering in IRGen:
_swift_tfc_GetCTensorHandleFromSwift
_swift_tfc_CreateTensorHandleFromC
It is unfortunate that this function has to switch on all supported dtypes, but that's the only way to create a concrete TensorHandle<T> type. The switch in dumpTensorContents should be unified with this.
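A rough sketch of why such a switch is needed: the scalar type is a compile-time generic parameter, so a runtime dtype tag has to be mapped to one concrete instantiation per case. The names below (`DTypeSketch`, `CHandleSketch`, `AnyHandleBase`, `TypedHandle`, `createHandle`) are hypothetical, the dtype list is abbreviated, and the fake C handle simply carries its dtype, where real code would query it from the handle.

```swift
// Hypothetical dtype tag; the real runtime uses the C API's dtype enum.
enum DTypeSketch { case float, double, int32, bool }

// Hypothetical C handle that can report its own dtype.
struct CHandleSketch { let dtype: DTypeSketch }

// Type-erased base class and generic subclass, as in the PR description.
class AnyHandleBase {
    let cHandle: CHandleSketch
    init(cHandle: CHandleSketch) { self.cHandle = cHandle }
}
final class TypedHandle<Scalar>: AnyHandleBase {}

// The dtype is only known at run time, so each supported dtype needs its
// own case that names the concrete generic instantiation.
func createHandle(fromC cHandle: CHandleSketch) -> AnyHandleBase {
    switch cHandle.dtype {
    case .float:  return TypedHandle<Float>(cHandle: cHandle)
    case .double: return TypedHandle<Double>(cHandle: cHandle)
    case .int32:  return TypedHandle<Int32>(cHandle: cHandle)
    case .bool:   return TypedHandle<Bool>(cHandle: cHandle)
    }
}
```

Because the switch is exhaustive over the enum in this sketch, adding a dtype case forces the compiler to flag every switch that has not been updated, which is one argument for sharing a single switch between call sites.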