Skip to content

Commit d5308bd

Browse files
authored
[DynamicCompilation] Add _AnyTensorHandle and Swift-C round trip entry points (#19529)
1 parent 1c5ea9d commit d5308bd

File tree

7 files changed

+129
-39
lines changed

7 files changed

+129
-39
lines changed

lib/SILOptimizer/Mandatory/TFPartition.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ static bool isUserIgnoredByPartitioning(SILInstruction *inst) {
9494
return isa<RefCountingInst>(inst);
9595
}
9696

97-
/// Given a decl for a struct that has a single field (typically because it is
98-
/// known to be a standard library type like Int or Float), return the canonical
99-
/// type of the single member, asserting and aborting if we get something
100-
/// unexpected.
97+
/// Given a decl for a struct or class that has a single field (typically
98+
/// because it is known to be a standard library type like Int or Float), return
99+
/// the canonical type of the single member, asserting and aborting if we get
100+
/// something unexpected.
101101
static CanType getSingleElementDeclFieldType(NominalTypeDecl *decl) {
102102
auto *field = tf::getFieldIfContainsSingleField(decl);
103103
assert(field && "Struct should have one member");
@@ -3777,7 +3777,13 @@ void TFFunctionPartition::insertTensorComputationStartEndTerminate(
37773777
auto tensorHandleDecl = ctx.getTensorHandleDecl();
37783778
assert(getSingleElementDeclFieldType(tensorHandleDecl) &&
37793779
"TensorHandle should have exactly one field");
3780-
auto tensorHandleMember = *tensorHandleDecl->getStoredProperties().begin();
3780+
auto *anyTensorHandleClass =
3781+
tensorHandleDecl->getSuperclass()->getAnyNominal();
3782+
auto anyTensorHandleSILTy = SILType::getPrimitiveObjectType(
3783+
anyTensorHandleClass->getDeclaredType()->getCanonicalType());
3784+
assert(anyTensorHandleClass);
3785+
auto *tensorHandleMember =
3786+
*anyTensorHandleClass->getStoredProperties().begin();
37813787

37823788
// Ownership markers for CTensorHandle accesses.
37833789
auto loadOwnership = hostFn.hasQualifiedOwnership()
@@ -3839,6 +3845,8 @@ void TFFunctionPartition::insertTensorComputationStartEndTerminate(
38393845
// it. If it is a scalar, then we need to box the scalar in a
38403846
// CTensorHandle.
38413847
if (isTensorHandle(tensorValue->getType().getASTType())) {
3848+
// Upcast to _AnyTensorHandle.
3849+
tensorValue = B.createUpcast(loc, tensorValue, anyTensorHandleSILTy);
38423850
auto fieldAddress =
38433851
B.createRefElementAddr(loc, tensorValue, tensorHandleMember);
38443852
tensorValue = B.createLoad(loc, fieldAddress, loadOwnership);
@@ -3977,8 +3985,9 @@ void TFFunctionPartition::insertTensorComputationStartEndTerminate(
39773985
/*objc*/ false, /*canAllocOnStack*/ false,
39783986
/*elementTypes*/ {},
39793987
/*elementCountOperands*/ {});
3988+
auto baseTH = B.createUpcast(loc, newTH, anyTensorHandleSILTy);
39803989
auto fieldAddress =
3981-
B.createRefElementAddr(result.getLoc(), newTH, tensorHandleMember);
3990+
B.createRefElementAddr(result.getLoc(), baseTH, tensorHandleMember);
39823991

39833992
B.createStore(result.getLoc(), newValue, fieldAddress, storeOwnership);
39843993

lib/SILOptimizer/Mandatory/TFUtilities.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,26 @@ llvm::raw_ostream *tf::getTFDumpIntermediateStream() {
9696
return &fileStream;
9797
}
9898

99-
/// If the specified decl has a single stored field, return it. Otherwise
100-
/// return null.
99+
/// Given a nominal type decl, collect all fields. If it's a class decl, collect
100+
/// all fields along the inheritance hierarchy.
101+
static void getAllFields(NominalTypeDecl *decl,
102+
SmallVectorImpl<VarDecl *> &fields) {
103+
for (auto *field : decl->getStoredProperties())
104+
fields.push_back(field);
105+
if (auto *classdecl = decl->getAsClassOrClassExtensionContext())
106+
if (auto *superclass = classdecl->getSuperclassDecl())
107+
getAllFields(superclass, fields);
108+
}
109+
110+
/// If the specified decl has a single stored field, return it. If it's a class
111+
/// type, return there's exactly one field in the entire inheritance hierarchy.
112+
/// Otherwise return null.
101113
VarDecl *tf::getFieldIfContainsSingleField(NominalTypeDecl *decl) {
102-
// Check to see if there is a single stored field.
103-
auto fieldIt = decl->getStoredProperties().begin();
104-
if (fieldIt == decl->getStoredProperties().end())
105-
return nullptr;
106-
auto result = *fieldIt++;
107-
if (fieldIt != decl->getStoredProperties().end())
108-
return nullptr;
109-
return result;
114+
SmallVector<VarDecl *, 4> fields;
115+
getAllFields(decl, fields);
116+
if (fields.size() == 1)
117+
return fields.front();
118+
return nullptr;
110119
}
111120

112121
bool tf::isTensorHandle(SILType ty) {

stdlib/public/TensorFlow/CompilerRuntime.swift

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,8 @@ internal func dumpCTensorHandleContent(
495495
let dType: TF_DataType = TFE_TensorHandleDataType(inputTensorHandle)
496496
debugLog("Tensor \(idx) has TF data type \(dType).")
497497
switch dType {
498-
case TF_INT8: dumpTensorContent(inputTensorHandle, Int8.self)
499498
case TF_UINT8: dumpTensorContent(inputTensorHandle, UInt8.self)
500-
case TF_INT16: dumpTensorContent(inputTensorHandle, Int16.self)
499+
case TF_INT8: dumpTensorContent(inputTensorHandle, Int8.self)
501500
case TF_UINT16: dumpTensorContent(inputTensorHandle, UInt16.self)
502501
case TF_INT16: dumpTensorContent(inputTensorHandle, Int16.self)
503502
case TF_UINT32: dumpTensorContent(inputTensorHandle, UInt32.self)
@@ -1132,7 +1131,7 @@ public func _TFCTerminateTensorComputation(_ computation: _TensorComputation) {
11321131
/// function.
11331132
@inlinable
11341133
@_silgen_name("_swift_tfc_CreateCTensorHandle")
1135-
public func _TFCCreateCTensorHandle<T>(_ value : T,
1134+
public func _TFCCreateCTensorHandle<T>(_ value: T,
11361135
_ dtype: TF_DataType) -> CTensorHandle {
11371136
// Create a new CTensor and initialize it to the scalar value.
11381137
let tensor = TF_AllocateTensor(dtype, nil, 0, MemoryLayout<T>.stride)
@@ -1170,6 +1169,37 @@ public func _TFCExtractCTensorHandle(
11701169
return handle.cTensorHandle
11711170
}
11721171

1172+
@inlinable
1173+
@_silgen_name("_swift_tfc_GetCTensorHandleFromSwift")
1174+
public func _TFCGetCTensorHandleFromSwift(
1175+
_ handle: _AnyTensorHandle
1176+
) -> CTensorHandle {
1177+
return handle.cTensorHandle
1178+
}
1179+
1180+
@inlinable
1181+
@_silgen_name("_swift_tfc_CreateTensorHandleFromC")
1182+
public func _TFCCreateTensorHandleFromC(
1183+
_ cHandle: CTensorHandle
1184+
) -> _AnyTensorHandle {
1185+
let dtype = TFE_TensorHandleDataType(cHandle)
1186+
switch dtype {
1187+
case TF_BFLOAT16: return TensorHandle<BFloat16>(owning: cHandle)
1188+
case TF_UINT8: return TensorHandle<UInt8>(owning: cHandle)
1189+
case TF_INT8: return TensorHandle<Int8>(owning: cHandle)
1190+
case TF_UINT16: return TensorHandle<UInt16>(owning: cHandle)
1191+
case TF_INT16: return TensorHandle<Int16>(owning: cHandle)
1192+
case TF_UINT32: return TensorHandle<UInt32>(owning: cHandle)
1193+
case TF_INT32: return TensorHandle<Int32>(owning: cHandle)
1194+
case TF_UINT64: return TensorHandle<UInt64>(owning: cHandle)
1195+
case TF_INT64: return TensorHandle<Int64>(owning: cHandle)
1196+
case TF_FLOAT: return TensorHandle<Float>(owning: cHandle)
1197+
case TF_DOUBLE: return TensorHandle<Double>(owning: cHandle)
1198+
case TF_BOOL: return TensorHandle<Bool>(owning: cHandle)
1199+
default: fatalError("Unsupported dtype \(dtype)")
1200+
}
1201+
}
1202+
11731203
@inlinable
11741204
@_silgen_name("_swift_tfc_CreateFloatTensorHandleFromCTensorHandle")
11751205
public func _TFCCreateTensorHandleFromCTensorHandle(

stdlib/public/TensorFlow/DataTypes.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import CTensorFlow
2323

24+
@_fixed_layout
2425
public struct _TensorDataType {
2526
internal var cDataType: TF_DataType
2627

stdlib/public/TensorFlow/TensorHandle.swift

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,52 @@
1616

1717
import CTensorFlow
1818

19-
/// `TensorHandle` is the type used by ops and the `#tfop()` syntax
20-
/// specifically. It includes a `Scalar` type, which compiler internals depend
21-
/// on to determine the datatypes of parameters when they are extracted
22-
/// into a tensor program.
19+
/// `_AnyTensorHandle` is the scalar-agnostic base type for `TensorHandle`, used
20+
/// specifically for low-level, type-erased passings of Swift-level tensor
21+
/// handles in the compiler.
2322
@_fixed_layout // required because the compiler accesses cTensorHandle directly.
24-
public final class TensorHandle<Scalar : AccelerableByTensorFlow> {
23+
public class _AnyTensorHandle {
2524
/// The underlying `TF_TensorHandle *`.
2625
///
27-
/// - Note: The compiler knows that `TensorHandle` has a single stored
26+
/// - Note: The compiler knows that `_AnyTensorHandle` has a single stored
2827
/// property, and assumes that this is it. Changing the design of
2928
/// `TensorHandle` will require tweaking the compiler.
30-
public let cTensorHandle: CTensorHandle
29+
@usableFromInline let cTensorHandle: CTensorHandle
30+
31+
/// Private initializer from a `CTensorHandle`. Should only be called from
32+
/// `TensorHandle<Scalar>.init`.
33+
fileprivate init(base: CTensorHandle) {
34+
self.cTensorHandle = base
35+
}
36+
}
3137

38+
/// `TensorHandle` is the type used by ops and the `#tfop()` syntax
39+
/// specifically. It includes a `Scalar` type, which compiler internals depend
40+
/// on to determine the datatypes of parameters when they are extracted
41+
/// into a tensor program.
42+
@_fixed_layout // required because the compiler accesses cTensorHandle directly.
43+
public final class TensorHandle<Scalar> : _AnyTensorHandle
44+
where Scalar : AccelerableByTensorFlow {
45+
@usableFromInline
46+
init(owning cTensorHandle: CTensorHandle) {
47+
super.init(base: cTensorHandle)
48+
}
49+
3250
@usableFromInline
33-
init(copyingFromCTensor cTensor: CTensor) {
51+
convenience init(copyingFromCTensor cTensor: CTensor) {
3452
let status = TF_NewStatus()
3553
let cTensorHandle = TFE_NewTensorHandle(cTensor, status)
3654
checkOk(status)
37-
self.cTensorHandle = cTensorHandle!
38-
55+
self.init(owning: cTensorHandle!)
3956
TF_DeleteStatus(status)
4057
}
4158

42-
@usableFromInline
43-
init(owning cTensorHandle: CTensorHandle) {
44-
self.cTensorHandle = cTensorHandle
59+
deinit {
60+
debugLog("De-initializing TensorHandle.")
61+
TFE_DeleteTensorHandle(cTensorHandle)
62+
debugLog("Returning from deinit of TensorHandle.")
4563
}
46-
64+
4765
/// Create a `TensorHandle` with a closure that initializes the underlying
4866
/// buffer.
4967
///
@@ -71,12 +89,6 @@ public final class TensorHandle<Scalar : AccelerableByTensorFlow> {
7189
self.init(copyingFromCTensor: cTensor)
7290
TF_DeleteTensor(cTensor)
7391
}
74-
75-
deinit {
76-
debugLog("De-initializing TensorHandle.")
77-
TFE_DeleteTensorHandle(cTensorHandle)
78-
debugLog("Returning from deinit of TensorHandle.")
79-
}
8092
}
8193

8294
internal extension TensorHandle {

test/TensorFlowRuntime/dynamic_compilation.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// This file contains testing over a dataset as a global variable. This requires
77
// sends/recvs support for variant handles.
88

9+
import CTensorFlow
910
import TensorFlow
1011
import TensorFlowUnittest
1112
import StdlibUnittest
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
// REQUIRES: swift_test_mode_optimize
4+
5+
import CTensorFlow
6+
import TensorFlow
7+
import TensorFlowUnittest
8+
import StdlibUnittest
9+
10+
var RuntimeEntryPointTests = TestSuite("RuntimeEntryPoint")
11+
12+
RuntimeEntryPointTests.testCPUOrGPU("RoundTrip_CTensorHandle_AnyTensorHandle") {
13+
let zero: TensorHandle<Float> =
14+
#tfop("Const", dtype: Float.self, value$tensor: 0.0)
15+
var cHandle = _TFCGetCTensorHandleFromSwift(zero as _AnyTensorHandle)
16+
let status = TF_NewStatus()
17+
// We must do a copy, i.e. a retain on the tensor handle, to make sure it won't
18+
// get double-free'd when both `zero` and `anyHandle` below go out of scope.
19+
cHandle = TFE_TensorHandleCopySharingTensor(cHandle, status)
20+
expectEqual(TF_GetCode(status), TF_OK)
21+
TF_DeleteStatus(status)
22+
let anyHandle = _TFCCreateTensorHandleFromC(cHandle)
23+
let tensor = Tensor(handle: anyHandle as! TensorHandle<Float>)
24+
print(tensor)
25+
expectTrue(tensor == Tensor(0.0))
26+
}
27+
28+
runAllTests()

0 commit comments

Comments
 (0)