Skip to content

Commit 1c5ea9d

Browse files
author
Mingsheng Hong
authored
Wrapped TF C APIs as runtime functions, so that IRGen can generate calls to them (#19555)
* Wrapped TF C APIs as runtime functions, so that IRGen can generate calls to them. This patch "disolved" the call to the packaged experimental C API TF_RunConstOp() into a set of finer-grained C API calls. An alternative to the runtime function approach is to use Clang importer to call the C APIs. Clang importer allows the C APIs to become available as SIL functions with the same names, and thus can be found and called in IRGen via silModule.findFunction() (see https://github.com/apple/swift/blob/ad7def2c6bffd95d62e5e665c9faed0f8dac49f5/lib/IRGen/IRGenSIL.cpp#L1956-L1957). The concerns with this approach are: 1. There are no precedents in IRGen that calls C APIs via this clang importer route. Having IRGen depend on clang importer may not be desirable. 2. Some C APIs cannot be found via clang importer (e.g. TF_NewStatus() can be found, but TFE_NewOp() cannot). As a work-around, these C APIs must be called in the user module code first, before IRGen over the user module can locate them as SIL functions. This feels brittle. 3. When passing objects between compiler runtime entry points (swift functions) such as @_silgen_name("_swift_tfc_GetGlobalEagerContext") and the C APIs, IRGen have to issue bitcast to convert between a real struct type (e.g. TFE_Context*) and void* (i8* in LLVM type) as used in the entry points. In contrast, the wrapped runtime functions consistently use void* in their interfaces, so that IRGen need not do such bitcasts. One downside with the runtime function approach is that we need to wrap ~20 TF C APIs in the new files TensorFlow.{h,cpp}. This is however mostly a one-time eng cost. * Reverted some compiler rt changes. * Addressed feedback: 1. Removed runtime func impls, when we can us the C API ones. 2. Changed some compiler RT entry points to using @_cdecl, so that they can be called in IRGen directly (instead of going through clang importer and silgen names). * Addressed comments from @rxwei * Added "find_package(TensorFlow REQUIRED)" to swift/stdlib/public/runtime/CMakeLists.txt, because CI run (e.g. https://ci-external.swift.org/view/Pull%20Request/job/swift-PR-TensorFlow-Linux/859/console) reported: ``` CMake Error at stdlib/public/runtime/CMakeLists.txt:30 (include_directories): include_directories given empty-string as include directory. ``` * Moved code in TensorFlow.{h,cpp} to the CTensorFLow module (ctensorflow_init.{h,cpp}), so that the stdlib/public/runtime/ need not depend on the TF library. Also addresed more feedback.
1 parent 835215f commit 1c5ea9d

File tree

7 files changed

+277
-68
lines changed

7 files changed

+277
-68
lines changed

include/swift/Runtime/RuntimeFunctions.def

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,83 @@ FUNCTION(AutoDiffPopFromTape, swift_autodiffPopFromTape, C_CC,
13061306
ARGS(OpaquePtrTy),
13071307
ATTRS(NoUnwind))
13081308

1309+
// SWIFT_ENABLE_TENSORFLOW
1310+
//===----------------------------------------------------------------------===//
1311+
// - MARK: Runtime functions issued by IRGen in dynamic compilation mode.
1312+
// - Naming convention:
1313+
// - TF_XXX functions correspond to TF C APIs
1314+
// - TFE_XXX functions correspond to TF eager C APIs
1315+
// - TFC_XXX functions are C functions defined in the compiler runtime
1316+
//===----------------------------------------------------------------------===//
1317+
1318+
// Function declarations, with the definitions are in the C API impl.
1319+
FUNCTION(TF_NewStatus, TF_NewStatus, C_CC,
1320+
RETURNS(Int8PtrTy),
1321+
ARGS(),
1322+
ATTRS(NoUnwind))
1323+
1324+
FUNCTION(TF_DeleteStatus, TF_DeleteStatus, C_CC,
1325+
RETURNS(),
1326+
ARGS(Int8PtrTy),
1327+
ATTRS(NoUnwind))
1328+
1329+
FUNCTION(TFE_NewOp, TFE_NewOp, C_CC,
1330+
RETURNS(Int8PtrTy),
1331+
ARGS(Int8PtrTy, Int8PtrTy, Int8PtrTy),
1332+
ATTRS(NoUnwind))
1333+
1334+
FUNCTION(TFE_DeleteOp, TFE_DeleteOp, C_CC,
1335+
RETURNS(),
1336+
ARGS(Int8PtrTy),
1337+
ATTRS(NoUnwind))
1338+
1339+
FUNCTION(TFE_OpSetAttrType, TFE_OpSetAttrType, C_CC,
1340+
RETURNS(),
1341+
ARGS(Int8PtrTy, Int8PtrTy, Int32Ty),
1342+
ATTRS(NoUnwind))
1343+
1344+
FUNCTION(TFE_OpSetAttrTensor, TFE_OpSetAttrTensor, C_CC,
1345+
RETURNS(),
1346+
ARGS(Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy),
1347+
ATTRS(NoUnwind))
1348+
1349+
FUNCTION(TF_DeleteTensor, TF_DeleteTensor, C_CC,
1350+
RETURNS(),
1351+
ARGS(Int8PtrTy),
1352+
ATTRS(NoUnwind))
1353+
1354+
// Functions with definitions in runtime library.
1355+
FUNCTION(TFC_CreateScalarFloatTensor, swift_tfc_CreateScalarFloatTensor, C_CC,
1356+
RETURNS(Int8PtrTy),
1357+
ARGS(Int32Ty),
1358+
ATTRS(NoUnwind))
1359+
1360+
FUNCTION(TFE_Execute, swift_tfc_TFE_Execute, C_CC,
1361+
RETURNS(),
1362+
ARGS(Int8PtrTy, Int8PtrPtrTy, Int32PtrTy, Int8PtrTy),
1363+
ATTRS(NoUnwind))
1364+
1365+
FUNCTION(TFC_GetGlobalEagerContext, _swift_tfc_GetGlobalEagerContext, C_CC,
1366+
RETURNS(Int8PtrTy),
1367+
ARGS(),
1368+
ATTRS(NoUnwind))
1369+
1370+
// TODO: enable these decls once we have AnyTensorHandle.
1371+
// FUNCTION(TFC_ExtractFloatCTensorHandle, _swift_tfc_ExtractFloatCTensorHandle, C_CC,
1372+
// RETURNS(Int8PtrTy),
1373+
// ARGS(Int8PtrTy),
1374+
// ATTRS(NoUnwind))
1375+
1376+
// FUNCTION(TFC_CreateFloatTensorHandleFromCTensorHandle, _swift_tfc_CreateFloatTensorHandleFromCTensorHandle, C_CC,
1377+
// RETURNS(Int8PtrTy),
1378+
// ARGS(Int8PtrTy),
1379+
// ATTRS(NoUnwind))
1380+
1381+
FUNCTION(TFC_CheckOk, _swift_tfc_CheckOk, C_CC,
1382+
RETURNS(),
1383+
ARGS(Int8PtrTy),
1384+
ATTRS(NoUnwind))
1385+
13091386
#undef RETURNS
13101387
#undef ARGS
13111388
#undef ATTRS

lib/IRGen/IRGenSIL.cpp

Lines changed: 148 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,23 @@ class IRGenSILFunction :
928928
}
929929
}
930930
}
931-
931+
932+
// SWIFT_ENABLE_TENSORFLOW
933+
// Returns the LLVM function with `funcName`. It must exist.
934+
llvm::Function *findFunction(StringRef funcName, SILModule &silModule) {
935+
LLVM_DEBUG(llvm::dbgs() << "IRGen for calling " << funcName << "().\n");
936+
auto silFn = silModule.findFunction(funcName, SILLinkage::PublicExternal);
937+
assert(silFn);
938+
llvm::Function *fn = IGM.getAddrOfSILFunction(silFn, NotForDefinition);
939+
assert(fn);
940+
return fn;
941+
}
942+
943+
void checkOk(llvm::Value *status) {
944+
auto *checkOkFn = IGM.getTFC_CheckOkFn();
945+
Builder.CreateCall(checkOkFn, {status});
946+
}
947+
932948
//===--------------------------------------------------------------------===//
933949
// SIL instruction lowering
934950
//===--------------------------------------------------------------------===//
@@ -1897,6 +1913,23 @@ static void abortOnGraphOp(IRGenFunction &IGF, llvm::StringRef errMessage) {
18971913
IGF.Builder.CreateCall(abortFunc, {});
18981914
}
18991915

1916+
// Create and return a (i8*-typed) address value to a constant string.
1917+
static llvm::Value *createStringValAddr(IRGenModule &IGM, StringRef strVal) {
1918+
auto &llvmModule = IGM.Module;
1919+
auto &llvmContext = llvmModule.getContext();
1920+
auto opNameVal = llvm::ConstantDataArray::getString(llvmContext, strVal);
1921+
auto global =
1922+
new llvm::GlobalVariable(IGM.Module, opNameVal->getType(), true,
1923+
llvm::GlobalValue::PrivateLinkage, opNameVal);
1924+
1925+
// Make an i8*.
1926+
auto zero = llvm::ConstantInt::get(IGM.Int32Ty, 0);
1927+
llvm::Constant *indices[] = {zero, zero};
1928+
auto opNameValAddr = llvm::ConstantExpr::getInBoundsGetElementPtr(
1929+
global->getValueType(), global, indices);
1930+
return opNameValAddr;
1931+
}
1932+
19001933
/// For now we lower any graph_op inst into a const TF node. (super fast tensor
19011934
/// computation, but those who don't care about correctness. :-) )
19021935
/// TODO: Fix this mis-compilation.
@@ -1938,59 +1971,127 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
19381971

19391972
auto &silModule = CurSILFn->getModule();
19401973

1974+
auto *TFNewStatusFn = IGM.getTF_NewStatusFn();
1975+
auto status = Builder.CreateCall(TFNewStatusFn, {});
1976+
1977+
if (opInfo.getOperationName() != "Const") {
1978+
LLVM_DEBUG(llvm::dbgs()
1979+
<< "Done with IRGen for dummy graph_op; setting explosion.\n");
1980+
// The added explosion is incorrect, but is good enough for unit testing.
1981+
Explosion e;
1982+
e.add(TFNewStatusFn);
1983+
1984+
SILValue result = i->getResults()[0];
1985+
setLoweredExplosion(result, e);
1986+
return;
1987+
}
1988+
19411989
// The true return type is TFE_Context*, which is an opaque pointer, so it
19421990
// maps to void* in the Swift-C calling convention. `eagerContext` has type
19431991
// void*, or i8* in LLVM type system.
1944-
auto getContextSilFn = silModule.findFunction(
1945-
"_swift_tfc_GetGlobalEagerContext", SILLinkage::PublicExternal);
1946-
assert(getContextSilFn);
1947-
llvm::Constant *getContextFn =
1948-
IGM.getAddrOfSILFunction(getContextSilFn, NotForDefinition);
1949-
assert(getContextFn);
1992+
auto *getContextFn = IGM.getTFC_GetGlobalEagerContextFn();
19501993
auto eagerContext = Builder.CreateCall(getContextFn, {});
19511994

1952-
// For now we call a hard-coded C API to run a const op:
1953-
// TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx)
1954-
// TODO: Remove this hard-coded C API call.
1955-
LLVM_DEBUG(llvm::dbgs() << "IRGen for TFE_RunConstOp().\n");
1956-
auto TFERunConstSilFn =
1957-
silModule.findFunction("TFE_RunConstOp", SILLinkage::PublicExternal);
1958-
assert(TFERunConstSilFn);
1959-
llvm::Function *TFERunConstFn =
1960-
IGM.getAddrOfSILFunction(TFERunConstSilFn, NotForDefinition);
1961-
assert(TFERunConstFn);
1962-
1963-
// We need to cast `eagerContext` of type i8* to %struct.TFE_Context*
1964-
auto *funcTy = TFERunConstFn->getFunctionType();
1965-
assert(funcTy->getNumParams() == 1);
1966-
auto *tfeContextTy = funcTy->getParamType(0);
1967-
LLVM_DEBUG(llvm::dbgs() << " Param 0 of TFE_RunConstOp() has type "
1968-
<< *tfeContextTy << ".\n");
1969-
auto eagerContextTyped = Builder.CreateBitCast(eagerContext, tfeContextTy);
1970-
1971-
LLVM_DEBUG(llvm::dbgs() << " Creating call over TFE_RunConstOp().\n");
1972-
auto cTensorHandle = Builder.CreateCall(TFERunConstFn, {eagerContextTyped});
1995+
// Create a TFE op as in:
1996+
// auto* op = TFE_NewOp(ctx, "Const", status);
1997+
auto *TFENewOpFn = IGM.getTFE_NewOpFn();
1998+
// TODO: remove the hard-coded "Const"
1999+
auto opNameValAddr = createStringValAddr(IGM, "Const");
2000+
auto op =
2001+
Builder.CreateCall(TFENewOpFn, {eagerContext, opNameValAddr, status});
2002+
checkOk(status);
2003+
2004+
// Set up dtype attr as in:
2005+
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
2006+
auto *setAttrTypeFn = IGM.getTFE_OpSetAttrTypeFn();
2007+
auto dtypeValAddr = createStringValAddr(IGM, "dtype");
2008+
// TODO: do not hard-code 1 as the TF_FLOAT enum value.
2009+
Builder.CreateCall(
2010+
setAttrTypeFn,
2011+
{op, dtypeValAddr, llvm::ConstantInt::get(IGM.Int32Ty, 1 /*TF_FLOAT*/)});
2012+
2013+
auto typeAttr = i->getAttribute(0);
2014+
auto typeAttrInfo =
2015+
GraphOperationInfo::decodeArgumentName(typeAttr.name.str());
2016+
assert(typeAttrInfo.first == "dtype");
2017+
assert(typeAttr.value.getKind() == SymbolicValue::Metatype);
2018+
2019+
auto valueAttr = i->getAttribute(1);
2020+
auto valueAttrInfo =
2021+
GraphOperationInfo::decodeArgumentName(valueAttr.name.str());
2022+
assert(valueAttrInfo.first == "value");
2023+
assert(valueAttr.value.getKind() == SymbolicValue::Float);
2024+
assert(valueAttrInfo.second ==
2025+
GraphOperationInfo::ArgumentLowering::TensorAttribute);
2026+
auto apfloat = valueAttr.value.getFloatValue();
2027+
// CreateScalarFloatTensor() takes an int instead of float, as runtime
2028+
// functions that take/return float values do not yet exist.
2029+
auto constVal = llvm::ConstantInt::get(IGM.Int32Ty, apfloat.convertToFloat());
2030+
LLVM_DEBUG(llvm::dbgs() << "The const value is " << *constVal << ".\n");
2031+
2032+
auto *createTensorFn = IGM.getTFC_CreateScalarFloatTensorFn();
2033+
auto tensor = Builder.CreateCall(createTensorFn, {constVal});
2034+
2035+
// Set up the tensor-typed value attr as in:
2036+
// TFE_OpSetAttrTensor(op, "value", tensor, status);
2037+
auto *setTensorAttrFn = IGM.getTFE_OpSetAttrTensorFn();
2038+
auto valueAttrAddr = createStringValAddr(IGM, "value");
2039+
Builder.CreateCall(setTensorAttrFn, {op, valueAttrAddr, tensor, status});
2040+
checkOk(status);
2041+
2042+
auto *deleteTensorFn = IGM.getTF_DeleteTensorFn();
2043+
Builder.CreateCall(deleteTensorFn, {tensor});
2044+
2045+
// Now we execute the TFE op as in:
2046+
// TFE_TensorHandle* retval;
2047+
// int num_retvals = 1;
2048+
// TFE_Execute(op, &retval, &num_retvals, status);
2049+
//
2050+
// The LLVM IR code looks like:
2051+
// %returnValues = alloca %struct.TFE_TensorHandle*, align 8
2052+
// %returnValueCount = alloca i32, align 8
2053+
// store i32 1, i32* %returnValueCount, align 8
2054+
// %134 = bitcast i8** %returnValues to %struct.TFE_TensorHandle**
2055+
// call void @TFE_Execute(%struct.TFE_Op* %130,
2056+
// %struct.TFE_TensorHandle** %134,
2057+
// i32* %returnValueCount,
2058+
// %struct.TF_Status* %128)
2059+
// %135 = load %struct.TFE_TensorHandle*, %struct.TFE_TensorHandle** %134,
2060+
//
2061+
// FIXME: getPointerAlignment is likely excessive. "align 4" might be
2062+
// sufficient.
2063+
auto returnValueCount =
2064+
createAlloca(IGM.Int32Ty, IGM.getPointerAlignment(), "returnValueCount");
2065+
auto expectedReturnValueCount =
2066+
llvm::ConstantInt::get(IGM.Int32Ty, i->getNumResults());
2067+
Builder.CreateStore(expectedReturnValueCount, returnValueCount);
2068+
auto returnValues = createAlloca(IGM.Int8PtrTy, expectedReturnValueCount,
2069+
IGM.getPointerAlignment(), "returnValues");
2070+
auto *tfeExecuteFn = IGM.getTFE_ExecuteFn();
2071+
Builder.CreateCall(tfeExecuteFn, {op, returnValues.getAddress(),
2072+
returnValueCount.getAddress(), status});
2073+
checkOk(status);
2074+
2075+
// TODO: add sanity check that the returned returnValueCount has value equal
2076+
// to expectedReturnValueCount.
2077+
2078+
// Clean up env as in:
2079+
// TFE_DeleteOp(op);
2080+
// TF_DeleteStatus(status);
2081+
auto *deleteOpFn = IGM.getTFE_DeleteOpFn();
2082+
Builder.CreateCall(deleteOpFn, {op});
2083+
auto *deleteStatusFn = IGM.getTF_DeleteStatusFn();
2084+
Builder.CreateCall(deleteStatusFn, {status});
2085+
2086+
auto cTensorHandle =
2087+
Builder.CreateLoad(returnValues.getAddress(), IGM.getPointerAlignment());
2088+
LLVM_DEBUG(llvm::dbgs() << "The returned tensor handle is " << *cTensorHandle
2089+
<< ".\n");
19732090

19742091
// Wrap `cTensorHandle` into a TensorHandle<T> object.
1975-
// This requires casting `cTensorHandle` of i8* type to
1976-
// %struct.TFE_TensorHandle*.
1977-
LLVM_DEBUG(llvm::dbgs() << "IRGen for creating result TensorHandle.\n");
1978-
auto createHandleSilFn = silModule.findFunction(
1979-
"_swift_tfc_CreateFloatTensorHandleFromCTensorHandle",
1980-
SILLinkage::PublicExternal);
1981-
assert(createHandleSilFn);
1982-
llvm::Function *createHandleFn =
1983-
IGM.getAddrOfSILFunction(createHandleSilFn, NotForDefinition);
1984-
assert(createHandleFn);
1985-
auto *createHandleFnTy = createHandleFn->getFunctionType();
1986-
assert(createHandleFnTy->getNumParams() == 1);
1987-
auto *cTensorHandleTy = createHandleFnTy->getParamType(0);
1988-
LLVM_DEBUG(llvm::dbgs() << " Param 0 of tensor handle creation fn has type "
1989-
<< *cTensorHandleTy << ".\n");
1990-
auto cTensorHandleTyped =
1991-
Builder.CreateBitCast(cTensorHandle, cTensorHandleTy);
1992-
LLVM_DEBUG(llvm::dbgs() << " Creating call over tensor handle creation.\n");
1993-
auto tensorHandle = Builder.CreateCall(createHandleFn, {cTensorHandleTyped});
2092+
llvm::Function *createHandleFn = findFunction(
2093+
"_swift_tfc_CreateFloatTensorHandleFromCTensorHandle", silModule);
2094+
auto tensorHandle = Builder.CreateCall(createHandleFn, {cTensorHandle});
19942095

19952096
LLVM_DEBUG(
19962097
llvm::dbgs() << "Done with IRGen for graph_op; setting explosion.\n");

stdlib/public/CTensorFlow/ctensorflow_init.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "ctensorflow_init.h"
22

3+
#include "tensorflow/c/c_api.h"
4+
#include "tensorflow/c/c_api_experimental.h"
5+
#include "tensorflow/c/eager/c_api.h"
36
#include "tensorflow/core/platform/init_main.h"
47

58
#include <assert.h>
@@ -31,4 +34,22 @@ void InitTensorFlowRuntime(unsigned char enable_debug_logging,
3134
tensorflow::port::InitMain(/*usage=*/nullptr, &my_argc, &tmpArgv);
3235
}
3336

37+
void *swift_tfc_CreateScalarFloatTensor(int32_t val) {
38+
auto *tensor =
39+
TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
40+
TF_DataTypeSize(TF_FLOAT) * 1);
41+
auto *ptr = reinterpret_cast<char *>(TF_TensorData(tensor));
42+
*reinterpret_cast<float *>(ptr) = static_cast<float>(val);
43+
return tensor;
44+
}
45+
46+
void swift_tfc_TFE_Execute(void *op, void **retvals, int32_t *num_retvals,
47+
void *status) {
48+
int int_num_retvals = *num_retvals;
49+
TFE_Execute(reinterpret_cast<TFE_Op *>(op),
50+
reinterpret_cast<TFE_TensorHandle **>(retvals), &int_num_retvals,
51+
reinterpret_cast<TF_Status *>(status));
52+
*num_retvals = int_num_retvals;
53+
}
54+
3455
} // extern "C"

stdlib/public/CTensorFlow/ctensorflow_init.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef SWIFT_SRC_SWIFT_STDLIB_CTENSORFLOW_INIT_H_
22
#define SWIFT_SRC_SWIFT_STDLIB_CTENSORFLOW_INIT_H_
33

4+
#include <stdint.h>
5+
46
#ifdef __cplusplus
57
extern "C" {
68
#endif
@@ -14,6 +16,17 @@ extern "C" {
1416
extern void InitTensorFlowRuntime(unsigned char enable_debug_logging,
1517
int verbose_level);
1618

19+
//===----------------------------------------------------------------------===//
20+
// - MARK: Runtime functions to be called via IRGen.
21+
//===----------------------------------------------------------------------===//
22+
23+
// Caller owns the returned tensor.
24+
// TODO: Generalize to create tensors from other shapes and dtypes.
25+
void *swift_tfc_CreateScalarFloatTensor(int32_t val);
26+
27+
void swift_tfc_TFE_Execute(void *op, void **retvals, int32_t *num_retvals,
28+
void *status);
29+
1730
#ifdef __cplusplus
1831
} /* end extern "C" */
1932
#endif

0 commit comments

Comments
 (0)