Skip to content

[DynamicCompilation] Add _AnyTensorHandle and Swift-C round trip entry points #19529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 27, 2018

Conversation

rxwei
Copy link
Contributor

@rxwei rxwei commented Sep 25, 2018

Introduce a scalar-type-erased _AnyTensorHandle as a base class for TensorHandle<T>, along with two runtime entry points that assist graph_op lowering in IRGen.

  1. _swift_tfc_GetCTensorHandleFromSwift

    @_silgen_name("_swift_tfc_GetCTensorHandleFromSwift")
    public func _GetCTensorHandleFromSwift(_: _AnyTensorHandle) -> CTensorHandle
  2. _swift_tfc_CreateTensorHandleFromC

    @_silgen_name("_swift_tfc_CreateTensorHandleFromC")
    public func _CreateTensorHandleFromC(_: CTensorHandle, _: TF_DataType) -> _AnyTensorHandle

    It is unfortunate that this function has to switch on all supported dtypes, but that's the only way to create a concrete TensorHandle<T> type. The switch in dumpTensorContents should be unified with this.

@rxwei rxwei added the tensorflow This is for "tensorflow" branch PRs. label Sep 25, 2018
@rxwei rxwei requested review from mhong and lattner September 25, 2018 06:25
@rxwei
Copy link
Contributor Author

rxwei commented Sep 25, 2018

@swift-ci please test tensorflow

@rxwei
Copy link
Contributor Author

rxwei commented Sep 25, 2018

@swift-ci please test tensorflow Linux GPU

Copy link

@mhong mhong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice patch! Left some comments.

@@ -44,6 +42,20 @@ public final class TensorHandle<Scalar : AccelerableByTensorFlow> {
self.cTensorHandle = cTensorHandle
}

deinit {
debugLog("De-initializing TensorHandle.")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TensorHandle -> _AnyTensorHandle?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I'm going to move initializers back into TensorHandle<T> since only that would be correct, and _AnyTensorHandle is just a type-erased class for easier IRGen. So it's always a TensorHandle.

cHandle = TFE_TensorHandleCopySharingTensor(cHandle, status)
expectEqual(TF_GetCode(status), TF_OK)
TF_DeleteStatus(status)
let anyHandle = _CreateTensorHandleFromC(cHandle, TF_FLOAT)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of passing in TF_FLOAT, can we compute that via TFE_TensorHandleDataType() within TFE_TensorHandleCopySharingTensor()? This way within IRGen, we won't need to remember the dtype for each tensor handle. (Otherwise we may want to return that dtype in _GetCTensorHandleFromSwift(), which complicates things and is not necessary for IRGen).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -99,6 +99,11 @@ static bool isUserIgnoredByPartitioning(SILInstruction *inst) {
/// unexpected.
static CanType getSingleElementDeclFieldType(NominalTypeDecl *decl) {
auto *field = tf::getFieldIfContainsSingleField(decl);
if (!field)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please extend the API comment to describe/cover this new case.

let status = TF_NewStatus()
// We must do a copy, i.e. a retain on the tensor handle, to make sure it won't
// get double-free'd when both `zero` and `anyHandle` below go out of scope.
cHandle = TFE_TensorHandleCopySharingTensor(cHandle, status)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm this suggests using TensorHandle<...>(owning: cHandle) within _GetCTensorHandleFromSwift() is not correct. A c handle should not be owned by 2 or more TensorHandle objects.

In IRgen, for a given tensor handle x, we may need to call _GetCTensorHandleFromSwift() on it multiple times, if x is consumed by multiple tfops. We don't want IRGen code to call TFE_TensorHandleCopySharingTensor to rebalance the ref count.

Can we call TFE_TensorHandleCopySharingTensor() within _GetCTensorHandleFromSwift()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the corresponding TF call emitted by IRGen consume the retain count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed in person, making _GetCTensorHandleFromSwift not make a copy is the correct behavior since the lifetime of each graph_op argument is guaranteed by construction.

This test case is irrelevant because it's making a new TensorHandle own the extracted CTensorHandle, which is for testing purpose and will not be emitted by IRGen.

Copy link
Contributor

@lattner lattner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thank you for driving this Richard!

@rxwei
Copy link
Contributor Author

rxwei commented Sep 27, 2018

@swift-ci please test tensorflow

2 similar comments
@rxwei
Copy link
Contributor Author

rxwei commented Sep 27, 2018

@swift-ci please test tensorflow

@rxwei
Copy link
Contributor Author

rxwei commented Sep 27, 2018

@swift-ci please test tensorflow

@rxwei rxwei merged commit d5308bd into swiftlang:tensorflow Sep 27, 2018
@rxwei rxwei deleted the runtime-entry-point branch September 27, 2018 05:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tensorflow This is for "tensorflow" branch PRs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants