Skip to content

Commit 7eb1710

Browse files
author
Mingsheng Hong
authored
Added tf dtype generics support in eager mode. (#19588)
* Added tf dtype generics support in eager mode. This is done by extending IRGen to use the new AnyTensorHandle based compiler rt entry points. * Advanced tensorflow repo commit hash to pick up the new experimental API TF_MakeInternalErrorStatus(). * Addressed feedback.
1 parent 2b6b6ce commit 7eb1710

File tree

7 files changed

+114
-48
lines changed

7 files changed

+114
-48
lines changed

include/swift/Runtime/RuntimeFunctions.def

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,11 @@ FUNCTION(TFC_CreateScalarFloatTensor, swift_tfc_CreateScalarFloatTensor, C_CC,
13621362
ARGS(Int32Ty),
13631363
ATTRS(NoUnwind))
13641364

1365+
FUNCTION(TFC_CreateScalarIntTensor, swift_tfc_CreateScalarIntTensor, C_CC,
1366+
RETURNS(Int8PtrTy),
1367+
ARGS(Int64Ty, Int32Ty, Int8PtrTy),
1368+
ATTRS(NoUnwind))
1369+
13651370
FUNCTION(TFE_Execute, swift_tfc_TFE_Execute, C_CC,
13661371
RETURNS(),
13671372
ARGS(Int8PtrTy, Int8PtrPtrTy, Int32PtrTy, Int8PtrTy),
@@ -1373,15 +1378,15 @@ FUNCTION(TFC_GetGlobalEagerContext, _swift_tfc_GetGlobalEagerContext, C_CC,
13731378
ATTRS(NoUnwind))
13741379

13751380
// TODO: enable these decls once we have AnyTensorHandle.
1376-
// FUNCTION(TFC_ExtractFloatCTensorHandle, _swift_tfc_ExtractFloatCTensorHandle, C_CC,
1377-
// RETURNS(Int8PtrTy),
1378-
// ARGS(Int8PtrTy),
1379-
// ATTRS(NoUnwind))
1380-
1381-
// FUNCTION(TFC_CreateFloatTensorHandleFromCTensorHandle, _swift_tfc_CreateFloatTensorHandleFromCTensorHandle, C_CC,
1382-
// RETURNS(Int8PtrTy),
1383-
// ARGS(Int8PtrTy),
1384-
// ATTRS(NoUnwind))
1381+
FUNCTION(TFC_GetCTensorHandleFromSwift, _swift_tfc_GetCTensorHandleFromSwift, C_CC,
1382+
RETURNS(Int8PtrTy),
1383+
ARGS(Int8PtrTy),
1384+
ATTRS(NoUnwind))
1385+
1386+
FUNCTION(TFC_CreateTensorHandleFromC, _swift_tfc_CreateTensorHandleFromC, C_CC,
1387+
RETURNS(Int8PtrTy),
1388+
ARGS(Int8PtrTy),
1389+
ATTRS(NoUnwind))
13851390

13861391
FUNCTION(TFC_CheckOk, _swift_tfc_CheckOk, C_CC,
13871392
RETURNS(),

lib/IRGen/IRGenSIL.cpp

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,8 +1968,6 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
19681968
// 2. Run the graph_op
19691969
// 3. Set the output tensor handles via setLoweredExplosion()
19701970

1971-
auto &silModule = CurSILFn->getModule();
1972-
19731971
auto *TFNewStatusFn = IGM.getTF_NewStatusFn();
19741972
auto status = Builder.CreateCall(TFNewStatusFn, {});
19751973

@@ -2015,8 +2013,7 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
20152013

20162014
auto tensorHandleValue =
20172015
getLoweredSingletonExplosion(tensorHandleSilValue);
2018-
llvm::Function *extractHandleFn =
2019-
findFunction("_swift_tfc_ExtractFloatCTensorHandle", silModule);
2016+
auto *extractHandleFn = IGM.getTFC_GetCTensorHandleFromSwiftFn();
20202017
auto cHandle = Builder.CreateCall(extractHandleFn, {tensorHandleValue});
20212018

20222019
// Add an op input as in:
@@ -2091,18 +2088,37 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
20912088
i->dump();
20922089
llvm_unreachable("dtype attr must have been processed!");
20932090
}
2094-
if (attr.value.getKind() != SymbolicValue::Float) {
2091+
2092+
llvm::Value *tensor = nullptr;
2093+
switch (attr.value.getKind()) {
2094+
case SymbolicValue::Float: {
2095+
auto apfloat = attr.value.getFloatValue();
2096+
// CreateScalarFloatTensor() takes an int instead of float, as runtime
2097+
// functions that take/return float values do not yet exist.
2098+
auto constVal =
2099+
llvm::ConstantInt::get(IGM.Int32Ty, apfloat.convertToFloat());
2100+
LLVM_DEBUG(llvm::dbgs() << "The const value is " << *constVal << ".\n");
2101+
2102+
auto *createTensorFn = IGM.getTFC_CreateScalarFloatTensorFn();
2103+
tensor = Builder.CreateCall(createTensorFn, {constVal});
2104+
break;
2105+
}
2106+
case SymbolicValue::Integer: {
2107+
auto apint = attr.value.getIntegerValue();
2108+
auto constVal = llvm::ConstantInt::get(
2109+
IGM.Int64Ty, apint.sextOrTrunc(64).getLimitedValue());
2110+
LLVM_DEBUG(llvm::dbgs() << "The const value is " << *constVal << ".\n");
2111+
2112+
auto *createTensorFn = IGM.getTFC_CreateScalarIntTensorFn();
2113+
tensor = Builder.CreateCall(
2114+
createTensorFn,
2115+
{constVal, llvm::ConstantInt::get(IGM.Int32Ty, dtypeAttr), status});
2116+
checkOk(status);
2117+
break;
2118+
}
2119+
default:
20952120
llvm_unreachable("TODO: support other dtypes for tensor attr.");
20962121
}
2097-
auto apfloat = attr.value.getFloatValue();
2098-
// CreateScalarFloatTensor() takes an int instead of float, as runtime
2099-
// functions that take/return float values do not yet exist.
2100-
auto constVal =
2101-
llvm::ConstantInt::get(IGM.Int32Ty, apfloat.convertToFloat());
2102-
LLVM_DEBUG(llvm::dbgs() << "The const value is " << *constVal << ".\n");
2103-
2104-
auto *createTensorFn = IGM.getTFC_CreateScalarFloatTensorFn();
2105-
auto tensor = Builder.CreateCall(createTensorFn, {constVal});
21062122

21072123
// Set up the tensor-typed value attr as in:
21082124
// TFE_OpSetAttrTensor(op, "value", tensor, status);
@@ -2167,8 +2183,7 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
21672183
<< ".\n");
21682184

21692185
// Wrap `cTensorHandle` into a TensorHandle<T> object.
2170-
llvm::Function *createHandleFn = findFunction(
2171-
"_swift_tfc_CreateFloatTensorHandleFromCTensorHandle", silModule);
2186+
auto *createHandleFn = IGM.getTFC_CreateTensorHandleFromCFn();
21722187
auto tensorHandle = Builder.CreateCall(createHandleFn, {cTensorHandle});
21732188

21742189
LLVM_DEBUG(

stdlib/public/CTensorFlow/ctensorflow_init.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,46 @@ void *swift_tfc_CreateScalarFloatTensor(int32_t val) {
4343
return tensor;
4444
}
4545

46+
void *swift_tfc_CreateScalarIntTensor(int64_t val, int32_t dtype,
47+
TF_Status *status) {
48+
auto tfDtype = (TF_DataType)dtype;
49+
auto *tensor =
50+
TF_AllocateTensor(tfDtype, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
51+
TF_DataTypeSize(tfDtype) * 1);
52+
auto *ptr = reinterpret_cast<char *>(TF_TensorData(tensor));
53+
54+
switch (tfDtype) {
55+
case TF_INT8:
56+
*reinterpret_cast<int8_t *>(ptr) = static_cast<int8_t>(val);
57+
break;
58+
case TF_UINT8:
59+
*reinterpret_cast<uint8_t *>(ptr) = static_cast<uint8_t>(val);
60+
break;
61+
case TF_INT16:
62+
*reinterpret_cast<int16_t *>(ptr) = static_cast<int16_t>(val);
63+
break;
64+
case TF_UINT16:
65+
*reinterpret_cast<uint16_t *>(ptr) = static_cast<uint16_t>(val);
66+
break;
67+
case TF_INT32:
68+
*reinterpret_cast<int32_t *>(ptr) = static_cast<int32_t>(val);
69+
break;
70+
case TF_UINT32:
71+
*reinterpret_cast<uint32_t *>(ptr) = static_cast<uint32_t>(val);
72+
break;
73+
case TF_INT64:
74+
*reinterpret_cast<int64_t *>(ptr) = static_cast<int64_t>(val);
75+
break;
76+
case TF_UINT64:
77+
*reinterpret_cast<uint64_t *>(ptr) = static_cast<uint64_t>(val);
78+
break;
79+
default:
80+
TF_MakeInternalErrorStatus(status, "Unsupported data type");
81+
return nullptr;
82+
}
83+
return tensor;
84+
}
85+
4686
void swift_tfc_TFE_Execute(void *op, void **retvals, int32_t *num_retvals,
4787
void *status) {
4888
int int_num_retvals = *num_retvals;

stdlib/public/CTensorFlow/ctensorflow_init.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ extern void InitTensorFlowRuntime(unsigned char enable_debug_logging,
2424
// TODO: Generalize to create tensors from other shapes and dtypes.
2525
void *swift_tfc_CreateScalarFloatTensor(int32_t val);
2626

27+
struct TF_Status;
28+
//`val` will be cast to the C data type based on `dtype`, which is then used to
29+
// create the scalar tensor. e.g. For dtype = TF_INT8, int8_t will be used.
30+
void *swift_tfc_CreateScalarIntTensor(int64_t val, int32_t dtype,
31+
TF_Status *status);
32+
2733
void swift_tfc_TFE_Execute(void *op, void **retvals, int32_t *num_retvals,
2834
void *status);
2935

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,19 +1156,9 @@ func _TFCGetGlobalEagerContext() -> CTFEContext {
11561156
return _ExecutionContext.global.eagerContext
11571157
}
11581158

1159-
// TODO: replace these functions with generic ones that do not hard-code Float.
1160-
1161-
// TODO: use @_cdecl instead, once we make the input/output data types C-compatible.
1162-
// Current compiler error if we use @_cdecl: method cannot be marked @_cdecl
1163-
// because the type of the parameter cannot be represented in Objective-C
1164-
@inlinable
1165-
@_silgen_name("_swift_tfc_ExtractFloatCTensorHandle")
1166-
public func _TFCExtractCTensorHandle(
1167-
_ handle: TensorHandle<Float>
1168-
) -> CTensorHandle {
1169-
return handle.cTensorHandle
1170-
}
1171-
1159+
// Some of the functions are marked with @silgen_name instead of @_cdecl,
1160+
// because their input/output data types are not C-compatible
1161+
// (e.g. AnyTensorHandle).
11721162
@inlinable
11731163
@_silgen_name("_swift_tfc_GetCTensorHandleFromSwift")
11741164
public func _TFCGetCTensorHandleFromSwift(
@@ -1200,14 +1190,6 @@ public func _TFCCreateTensorHandleFromC(
12001190
}
12011191
}
12021192

1203-
@inlinable
1204-
@_silgen_name("_swift_tfc_CreateFloatTensorHandleFromCTensorHandle")
1205-
public func _TFCCreateTensorHandleFromCTensorHandle(
1206-
_ ownedCHandle: CTensorHandle
1207-
) -> TensorHandle<Float> {
1208-
return TensorHandle<Float>(owning: ownedCHandle)
1209-
}
1210-
12111193
@usableFromInline
12121194
@_cdecl("_swift_tfc_CheckOk")
12131195
func _TFCCheckOk(_ s: CTFStatus) {

test/TensorFlowRuntime/dynamic_compilation.swift

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ DynamicCompilationTests.testCPUOrGPU("Const") {
2323
expectNearlyEqualWithScalarTensor(1.0, Tensor<Float>(handle: x))
2424
}
2525

26-
DynamicCompilationTests.testCPUOrGPU("Add") {
26+
DynamicCompilationTests.testCPUOrGPU("AddFloat") {
2727
_RuntimeConfig.printsDebugLog = true
2828
let x = Tensor<Float>(1.0)
2929
let y = Tensor<Float>(2.0)
@@ -32,6 +32,24 @@ DynamicCompilationTests.testCPUOrGPU("Add") {
3232
expectNearlyEqualWithScalarTensor(3.0, z)
3333
}
3434

35+
DynamicCompilationTests.testCPUOrGPU("AddInt64") {
36+
_RuntimeConfig.printsDebugLog = true
37+
let x = Tensor<Int64>(1)
38+
let y = Tensor<Int64>(2)
39+
let z = x + y
40+
_hostOp(z)
41+
expectEqualWithScalarTensor(3, z)
42+
}
43+
44+
DynamicCompilationTests.testCPUOrGPU("AddInt32") {
45+
_RuntimeConfig.printsDebugLog = true
46+
let x = Tensor<Int32>(1)
47+
let y = Tensor<Int32>(2)
48+
let z = x + y
49+
_hostOp(z)
50+
expectEqualWithScalarTensor(3, z)
51+
}
52+
3553
#endif // !CUDA
3654

3755
runAllTests()

utils/update_checkout/update-checkout-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@
304304
"swift-integration-tests": "swift-DEVELOPMENT-SNAPSHOT-2018-08-06-a",
305305
"swift-xcode-playground-support": "swift-DEVELOPMENT-SNAPSHOT-2018-08-06-a",
306306
"ninja": "253e94c1fa511704baeb61cf69995bbf09ba435e",
307-
"tensorflow": "b5594e6121e902f8dd2d5127653a1ec5f97daccd",
307+
"tensorflow": "bdab0b3c111bbe1c9656fa2228f1a4d28df5a7bf",
308308
"tensorflow-swift-bindings": "e1983bdac0c64ba02f8c5c850f7c82436b5622e5"
309309
}
310310
}

0 commit comments

Comments
 (0)