Skip to content

Extended graph_op IRGen to support calling C APIs directly. #19524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 60 additions & 14 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
#include "swift/Basic/STLExtras.h"
#include "swift/SIL/Dominance.h"
#include "swift/SIL/InstructionUtils.h"
// SWIFT_ENABLE_TENSORFLOW
#include "swift/SIL/GraphOperationInfo.h"
#include "swift/SIL/PrettyStackTrace.h"
// SWIFT_ENABLE_TENSORFLOW
#include "swift/SIL/SILConstants.h"
#include "swift/SIL/SILDebugScope.h"
#include "swift/SIL/SILDeclRef.h"
#include "swift/SIL/SILLinkage.h"
Expand Down Expand Up @@ -81,6 +85,7 @@

using namespace swift;
using namespace irgen;
using swift::tf::GraphOperationInfo;

// FIXME: Remove this option entirely and turn this on by default.
llvm::cl::opt<bool> DebugInfoInlinedGenerics(
Expand Down Expand Up @@ -1916,9 +1921,6 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
return;
}

auto &llvmModule = IGM.Module;
auto &llvmContext = llvmModule.getContext();

// TODO: Remove these. They are a temporary way of testing that dynamic
// attributes make it here.
LLVM_DEBUG(llvm::dbgs() << "IRGen for graph_op: "
Expand All @@ -1934,20 +1936,64 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
// 2. Run the graph_op
// 3. Set the output tensor handles via setLoweredExplosion()

auto &silModule = CurSILFn->getModule();

// The true return type is TFE_Context*, which is an opaque pointer, so it
// maps to void* in the Swift-C calling convention.
auto getContextFn = llvmModule.getOrInsertFunction(
"_swift_tfc_GetGlobalEagerContext",
llvm::TypeBuilder<void *(), /*cross_compilable=*/false>::get(
llvmContext));
// maps to void* in the Swift-C calling convention. `eagerContext` has type
// void*, or i8* in LLVM type system.
auto getContextSilFn = silModule.findFunction(
"_swift_tfc_GetGlobalEagerContext", SILLinkage::PublicExternal);
assert(getContextSilFn);
llvm::Constant *getContextFn =
IGM.getAddrOfSILFunction(getContextSilFn, NotForDefinition);
assert(getContextFn);
auto eagerContext = Builder.CreateCall(getContextFn, {});

// The true function type is TFE_Context* -> TFE_TensorHandle*.
auto testFunc = llvmModule.getOrInsertFunction(
"_swift_tfc_RunEagerConstTest",
llvm::TypeBuilder<void *(void *), false>::get(llvmContext));
auto tensorHandle = Builder.CreateCall(testFunc, {eagerContext});

// For now we call a hard-coded C API to run a const op:
// TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx)
// TODO: Remove this hard-coded C API call.
LLVM_DEBUG(llvm::dbgs() << "IRGen for TFE_RunConstOp().\n");
auto TFERunConstSilFn =
silModule.findFunction("TFE_RunConstOp", SILLinkage::PublicExternal);
assert(TFERunConstSilFn);
llvm::Function *TFERunConstFn =
IGM.getAddrOfSILFunction(TFERunConstSilFn, NotForDefinition);
assert(TFERunConstFn);

// We need to cast `eagerContext` of type i8* to %struct.TFE_Context*
auto *funcTy = TFERunConstFn->getFunctionType();
assert(funcTy->getNumParams() == 1);
auto *tfeContextTy = funcTy->getParamType(0);
LLVM_DEBUG(llvm::dbgs() << " Param 0 of TFE_RunConstOp() has type "
<< *tfeContextTy << ".\n");
auto eagerContextTyped = Builder.CreateBitCast(eagerContext, tfeContextTy);

LLVM_DEBUG(llvm::dbgs() << " Creating call over TFE_RunConstOp().\n");
auto cTensorHandle = Builder.CreateCall(TFERunConstFn, {eagerContextTyped});

// Wrap `cTensorHandle` into a TensorHandle<T> object.
// This requires casting `cTensorHandle` of i8* type to
// %struct.TFE_TensorHandle*.
LLVM_DEBUG(llvm::dbgs() << "IRGen for creating result TensorHandle.\n");
auto createHandleSilFn = silModule.findFunction(
"_swift_tfc_CreateFloatTensorHandleFromCTensorHandle",
SILLinkage::PublicExternal);
assert(createHandleSilFn);
llvm::Function *createHandleFn =
IGM.getAddrOfSILFunction(createHandleSilFn, NotForDefinition);
assert(createHandleFn);
auto *createHandleFnTy = createHandleFn->getFunctionType();
assert(createHandleFnTy->getNumParams() == 1);
auto *cTensorHandleTy = createHandleFnTy->getParamType(0);
LLVM_DEBUG(llvm::dbgs() << " Param 0 of tensor handle creation fn has type "
<< *cTensorHandleTy << ".\n");
auto cTensorHandleTyped =
Builder.CreateBitCast(cTensorHandle, cTensorHandleTy);
LLVM_DEBUG(llvm::dbgs() << " Creating call over tensor handle creation.\n");
auto tensorHandle = Builder.CreateCall(createHandleFn, {cTensorHandleTyped});

LLVM_DEBUG(
llvm::dbgs() << "Done with IRGen for graph_op; setting explosion.\n");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all LGTM

Explosion e;
e.add(tensorHandle);

Expand Down
5 changes: 3 additions & 2 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
break;

case SILStage::Lowered:
llvm_unreachable("cannot deserialize into a module that has entered "
"Lowered stage");
if (!declarationOnly) // SWIFT_ENABLE_TENSORFLOW
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not super obvious to me this is a bug. Which line crashed when you didn't change this? Is it SILModule::findFunction?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. The crash stack trace is

(gdb) bt
#0  SignalHandler (Sig=6)
    at /usr/local/google/home/hongm/ssd_part/git/swift-base/llvm/lib/Support/Unix/Signals.inc:314
#1  <signal handler called>
#2  0x00007fffe39fcfcf in raise () from /lib/x86_64-linux-gnu/libc.so.6
#3  0x00007fffe39fe3fa in abort () from /lib/x86_64-linux-gnu/libc.so.6
#4  0x0000000004c6083b in llvm::llvm_unreachable_internal (msg=<optimized out>, 
    file=0x555206e "/usr/local/google/home/hongm/ssd_part/git/swift-base/swift/lib/Serialization/DeserializeSIL.cpp", line=431)
    at /usr/local/google/home/hongm/ssd_part/git/swift-base/llvm/lib/Support/ErrorHandling.cpp:189
#5  0x0000000002218aeb in swift::SILDeserializer::readSILFunctionChecked (this=0x8de05b0, FID=..., 
    existingFn=0x0, name="_swift_tfc_StartTensorComputation", declarationOnly=true, 
    errorIfEmptyBody=true)
    at /usr/local/google/home/hongm/ssd_part/git/swift-base/swift/lib/Serialization/DeserializeSIL.cpp:430
#6  0x0000000002230346 in swift::SILDeserializer::lookupSILFunction (this=0x8de05b0, 
    name="_swift_tfc_StartTensorComputation", declarationOnly=true)
    at /usr/local/google/home/hongm/ssd_part/git/swift-base/swift/lib/Serialization/DeserializeSIL.cpp:2561
#7  0x000000000218d7b8 in swift::SerializedSILLoader::lookupSILFunction (this=0x8c89630, 
    Name="_swift_tfc_StartTensorComputation", declarationOnly=true, 
    Linkage=llvm::Optional is initialized = {...})
    at /usr/local/google/home/hongm/ssd_part/git/swift-base/swift/lib/Serialization/SerializedSILLoader.cpp:67

SerializedSILLoader::lookupSILFunction() is called by SILModule::findFunction()

I think it's a bug because the comment above says:

We can't deserialize function bodies after ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is expected behavior since it's in IRGen, which turns the SIL stage to Lowered. Maybe SILModule::findFunction should not be used, because it'll try to deserialize a function when it's not in the current module.

Have you tried creating the declaration of _swift_tfc_GetGlobalEagerContext directly? This seems to be the standard approach in IRGen. See how IRGen creates the declaration of the runtime entry point for lowering retain_value_addr.
https://github.com/apple/swift/blob/c12db28b164e1207e918f6bd3f7119f80f0cdaa1/lib/IRGen/Outlining.cpp#L299

You could define a IRGenModule::getOrCreateTFCGetGlobalEagerContextFunction (same for other entry points), and getTypeAndGenericSignatureForManglingOutlineFunction seems to be able to handle mangling for you. Right, it's a lot of code defining a getter for each entry point, but it seems to be the standard way in the compiler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Richard suggests, I think this is expected behavior in IRGen.

As an alternative to creating declarations explicitly, can we remember the relevant SILFunctions somewhere in the context of GraphOp and use them in getAddrOfSILFunction?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointers. I'd like to avoid explicit declarations if at all possible -- that would increase code complexity, and also maintenance overhead (when the Swift function signature of _swift_tfc_GetGlobalEagerContext is changed, we need to change the irgen compiler code accordingly as well).

Also, while we only have ~5 compiler runtime entry points for now, we'll be calling ~20 TF eager C APIs (see TF repo tensorflow/tensorflow/c/eager/c_api.h), and creating explicit decls for them could involve a lot more code complexity and maintenance overhead.

The referenced code for getOrCreateRetainFunction() (I also studied swift/include/swift/Runtime/RuntimeFunctions.def) is of a different scenario -- these runtime functions do not have their impls already defined as external C APIs or swift functions (via silgen names).

I thought deserializing the SIL function decl (not body) in IRGen phase would still be safe. Do you have a specific concern?

@bgogul, it'd be possible to cache the SILFunction* at an earlier compiler phase (before we get into IRGen), but if the cached version is still valid to used, I believe the deserialized version (just the decl) would be safe to use as well, and the latter would be somewhat simpler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one concern with deserializing during IRGen phase is that we don't have a chance to run some of the passes that would have run if we deserialized that function at an early stage. Even though you are only using it here for getting the declaration, it goes against separation of concerns. Perhaps others can jump in if this is not a major concern.

Isn't the case that we check that TensorFlow module is imported at some point? Given that these functions are part of TensorFlow, why can't we cache them then and make sure they are valid.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we need to codegen any TensorFlow C API from IRGen directly? Can they be moved and concentrated to one or two entry points in the runtime?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline. I'll submit this patch, and start a discussion with core team on this "SILFunction lookup" based approach.

llvm_unreachable("cannot deserialize into a module that has entered "
"Lowered stage");
}

if (FID == 0)
Expand Down
17 changes: 17 additions & 0 deletions stdlib/public/TensorFlow/CompilerRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1166,3 +1166,20 @@ public func _GetGlobalEagerContext() -> CTFEContext {
return _ExecutionContext.global.eagerContext
}

// TODO: replace these functions with generic ones that do not hard-code Float.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will be more straight-forward to do when there is a non-generic superclass of TensorHandle.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. Will refresh this once I use Richard's patch.

@inlinable
@_silgen_name("_swift_tfc_ExtractFloatCTensorHandle")
public func _ExtractCTensorHandle(
_ handle: TensorHandle<Float>
) -> CTensorHandle {
return handle.cTensorHandle
}

@inlinable
@_silgen_name("_swift_tfc_CreateFloatTensorHandleFromCTensorHandle")
public func _CreateTensorHandleFromCTensorHandle(
_ ownedCHandle: CTensorHandle
) -> TensorHandle<Float> {
return TensorHandle<Float>(owning: ownedCHandle)
}