Skip to content

Commit c09ccb3

Browse files
author
Mingsheng Hong
committed
Extended graph_op IRGen to support calling C APIs directly.
In this PR, we call the TF C API TFE_RunConstOp() instead of the previous compiler runtime entry point @_silgen_name("_swift_tfc_RunEagerConstTest"), and then call the compiler runtime entry point @_silgen_name("_swift_tfc_CreateFloatTensorHandleFromCTensorHandle") to wrap the result into a TensorHandle<Float>. The main changes are: 1. To create an llvm::Function object based on a function name such as "_swift_tfc_GetGlobalEagerContext", we first call silModule.findFunction() to get the SILFunction, and then call IGM.getAddrOfSILFunction() to get the llvm::Function. To make this work, we fixed a bug in SILDeserializer::readSILFunctionChecked(), which previously prevented a function declaration (not its body) from being deserialized in the IRGen stage. 2. We obtain from llvm::Function objects the LLVM type objects for the C data types TFE_Context* and TFE_TensorHandle*, and then generate bitcast instructions to convert the function params into the right types before issuing the relevant function call. Next steps: 1. Replace the TFE_RunConstOp() call with a sequence of TF eager C API calls. 2. Generalize the graph_op decoding logic to handle graph_ops other than Const. 3. Support a generic TF datatype T instead of the hard-coded Float. 4. Figure out a way to do scalar promotion even in the case of -Onone, since Tensor<Float>(1.0) becomes a pseudo graph_op "tfc.scalarToTensor", which should be turned into a "Const" graph_op.
1 parent c12db28 commit c09ccb3

File tree

3 files changed

+80
-16
lines changed

3 files changed

+80
-16
lines changed

lib/IRGen/IRGenSIL.cpp

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
#include "swift/Basic/STLExtras.h"
3535
#include "swift/SIL/Dominance.h"
3636
#include "swift/SIL/InstructionUtils.h"
37+
// SWIFT_ENABLE_TENSORFLOW
38+
#include "swift/SIL/GraphOperationInfo.h"
3739
#include "swift/SIL/PrettyStackTrace.h"
40+
// SWIFT_ENABLE_TENSORFLOW
41+
#include "swift/SIL/SILConstants.h"
3842
#include "swift/SIL/SILDebugScope.h"
3943
#include "swift/SIL/SILDeclRef.h"
4044
#include "swift/SIL/SILLinkage.h"
@@ -81,6 +85,7 @@
8185

8286
using namespace swift;
8387
using namespace irgen;
88+
using swift::tf::GraphOperationInfo;
8489

8590
// FIXME: Remove this option entirely and turn this on by default.
8691
llvm::cl::opt<bool> DebugInfoInlinedGenerics(
@@ -1875,9 +1880,6 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
18751880
if (!llvm::TFDynamicCompilation)
18761881
llvm_unreachable("graph_op is not valid in canonical SIL");
18771882

1878-
auto &llvmModule = IGM.Module;
1879-
auto &llvmContext = llvmModule.getContext();
1880-
18811883
tf::GraphOperationInfo opInfo(i);
18821884
SmallVector<tf::GraphOperationInfo::StructuredOperand, 4> structuredOperands;
18831885
auto opName = opInfo.decodeName(structuredOperands);
@@ -1895,20 +1897,64 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
18951897
// 2. Run the graph_op
18961898
// 3. Set the output tensor handles via setLoweredExplosion()
18971899

1900+
auto &silModule = CurSILFn->getModule();
1901+
18981902
// The true return type is TFE_Context*, which is an opaque pointer, so it
1899-
// maps to void* in the Swift-C calling convention.
1900-
auto getContextFn = llvmModule.getOrInsertFunction(
1901-
"_swift_tfc_GetGlobalEagerContext",
1902-
llvm::TypeBuilder<void *(), /*cross_compilable=*/false>::get(
1903-
llvmContext));
1903+
// maps to void* in the Swift-C calling convention. `eagerContext` has type
1904+
// void*, or i8* in LLVM type system.
1905+
auto getContextSilFn = silModule.findFunction(
1906+
"_swift_tfc_GetGlobalEagerContext", SILLinkage::PublicExternal);
1907+
assert(getContextSilFn);
1908+
llvm::Constant *getContextFn =
1909+
IGM.getAddrOfSILFunction(getContextSilFn, NotForDefinition);
1910+
assert(getContextFn);
19041911
auto eagerContext = Builder.CreateCall(getContextFn, {});
19051912

1906-
// The true function type is TFE_Context* -> TFE_TensorHandle*.
1907-
auto testFunc = llvmModule.getOrInsertFunction(
1908-
"_swift_tfc_RunEagerConstTest",
1909-
llvm::TypeBuilder<void *(void *), false>::get(llvmContext));
1910-
auto tensorHandle = Builder.CreateCall(testFunc, {eagerContext});
1911-
1913+
// For now we call a hard-coded C API to run a const op:
1914+
// TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx)
1915+
// TODO: Remove this hard-coded C API call.
1916+
LLVM_DEBUG(llvm::dbgs() << "IRGen for TFE_RunConstOp().\n");
1917+
auto TFERunConstSilFn =
1918+
silModule.findFunction("TFE_RunConstOp", SILLinkage::PublicExternal);
1919+
assert(TFERunConstSilFn);
1920+
llvm::Function *TFERunConstFn =
1921+
IGM.getAddrOfSILFunction(TFERunConstSilFn, NotForDefinition);
1922+
assert(TFERunConstFn);
1923+
1924+
// We need to cast `eagerContext` of type i8* to %struct.TFE_Context*
1925+
auto *funcTy = TFERunConstFn->getFunctionType();
1926+
assert(funcTy->getNumParams() == 1);
1927+
auto *tfeContextTy = funcTy->getParamType(0);
1928+
LLVM_DEBUG(llvm::dbgs() << " Param 0 of TFE_RunConstOp() has type "
1929+
<< *tfeContextTy << ".\n");
1930+
auto eagerContextTyped = Builder.CreateBitCast(eagerContext, tfeContextTy);
1931+
1932+
LLVM_DEBUG(llvm::dbgs() << " Creating call over TFE_RunConstOp().\n");
1933+
auto cTensorHandle = Builder.CreateCall(TFERunConstFn, {eagerContextTyped});
1934+
1935+
// Wrap `cTensorHandle` into a TensorHandle<T> object.
1936+
// This requires casting `cTensorHandle` of i8* type to
1937+
// %struct.TFE_TensorHandle*.
1938+
LLVM_DEBUG(llvm::dbgs() << "IRGen for creating result TensorHandle.\n");
1939+
auto createHandleSilFn = silModule.findFunction(
1940+
"_swift_tfc_CreateFloatTensorHandleFromCTensorHandle",
1941+
SILLinkage::PublicExternal);
1942+
assert(createHandleSilFn);
1943+
llvm::Function *createHandleFn =
1944+
IGM.getAddrOfSILFunction(createHandleSilFn, NotForDefinition);
1945+
assert(createHandleFn);
1946+
auto *createHandleFnTy = createHandleFn->getFunctionType();
1947+
assert(createHandleFnTy->getNumParams() == 1);
1948+
auto *cTensorHandleTy = createHandleFnTy->getParamType(0);
1949+
LLVM_DEBUG(llvm::dbgs() << " Param 0 of tensor handle creation fn has type "
1950+
<< *cTensorHandleTy << ".\n");
1951+
auto cTensorHandleTyped =
1952+
Builder.CreateBitCast(cTensorHandle, cTensorHandleTy);
1953+
LLVM_DEBUG(llvm::dbgs() << " Creating call over tensor handle creation.\n");
1954+
auto tensorHandle = Builder.CreateCall(createHandleFn, {cTensorHandleTyped});
1955+
1956+
LLVM_DEBUG(
1957+
llvm::dbgs() << "Done with IRGen for graph_op; setting explosion.\n");
19121958
Explosion e;
19131959
e.add(tensorHandle);
19141960

lib/Serialization/DeserializeSIL.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
427427
break;
428428

429429
case SILStage::Lowered:
430-
llvm_unreachable("cannot deserialize into a module that has entered "
431-
"Lowered stage");
430+
if (!declarationOnly) // SWIFT_ENABLE_TENSORFLOW
431+
llvm_unreachable("cannot deserialize into a module that has entered "
432+
"Lowered stage");
432433
}
433434

434435
if (FID == 0)

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,3 +1166,20 @@ public func _GetGlobalEagerContext() -> CTFEContext {
11661166
return _ExecutionContext.global.eagerContext
11671167
}
11681168

1169+
// TODO: replace these functions with the generic versions commented out below.
1170+
@inlinable
1171+
@_silgen_name("_swift_tfc_ExtractFloatCTensorHandle")
1172+
public func _ExtractCTensorHandle(
1173+
_ handle: TensorHandle<Float>
1174+
) -> CTensorHandle {
1175+
return handle.cTensorHandle
1176+
}
1177+
1178+
@inlinable
1179+
@_silgen_name("_swift_tfc_CreateFloatTensorHandleFromCTensorHandle")
1180+
public func _CreateTensorHandleFromCTensorHandle(
1181+
_ ownedCHandle: CTensorHandle
1182+
) -> TensorHandle<Float> {
1183+
return TensorHandle<Float>(owning: ownedCHandle)
1184+
}
1185+

0 commit comments

Comments
 (0)