Skip to content

Commit cfe5d4f

Browse files
author
Mingsheng Hong
authored
Factored IRGenSILFunction::visitGraphOperationInst() to handle ops other than "Const". (#19575)
* Factored IRGenSILFunction::visitGraphOperationInst() to handle ops other than "Const". Added a new unit test based on "Add". * Polished the code a bit. * Reverted #include reordering made by clang-format, to avoid unnecessary potential merge conflicts when we downstream code from the master branch.
1 parent d5308bd commit cfe5d4f

File tree

11 files changed

+255
-151
lines changed

11 files changed

+255
-151
lines changed

include/swift/AST/TensorFlow.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ namespace tf {
3030
/// Return true if the given type represents a TensorFlow dtype.
3131
bool isTensorFlowDType(Type ty);
3232

33+
/// This function maps a Swift type (either a language type like Float or an
34+
/// LLVM Builtin type like Builtin.f32) into the TensorFlow TF_DataType value.
35+
unsigned convertSwiftTypeToTF(Type ty);
36+
3337
/// If the specified type is the well-known TensorHandle<T> type, then return
3438
/// "T". If not, return a null type.
3539
Type getTensorHandleElementType(Type ty);

include/swift/Runtime/RuntimeFunctions.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,11 @@ FUNCTION(TFE_DeleteOp, TFE_DeleteOp, C_CC,
13361336
ARGS(Int8PtrTy),
13371337
ATTRS(NoUnwind))
13381338

1339+
FUNCTION(TFE_OpAddInput, TFE_OpAddInput, C_CC,
1340+
RETURNS(),
1341+
ARGS(Int8PtrTy, Int8PtrTy, Int8PtrTy),
1342+
ATTRS(NoUnwind))
1343+
13391344
FUNCTION(TFE_OpSetAttrType, TFE_OpSetAttrType, C_CC,
13401345
RETURNS(),
13411346
ARGS(Int8PtrTy, Int8PtrTy, Int32Ty),

include/swift/SIL/GraphOperationInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#define SWIFT_SIL_GRAPH_OPERATION_INFO_H
2020

2121
#include "swift/AST/Identifier.h"
22+
#include "swift/AST/TensorFlow.h"
23+
#include "swift/SIL/SILType.h"
2224
#include "swift/SIL/SILValue.h"
2325
#include "llvm/ADT/SmallVector.h"
2426
#include "llvm/ADT/StringRef.h"
@@ -185,6 +187,11 @@ struct GraphOperationInfo {
185187
static std::pair<llvm::StringRef, ArgumentLowering>
186188
decodeArgumentName(StringRef Name);
187189
};
190+
191+
/// Determine whether the specified type is one of our well-known types, and
192+
/// if so, which one it is.
193+
TFValueKind classifyTensorFlowValue(SILType ty);
194+
188195
} // end namespace tf
189196
} // end namespace swift
190197
#endif // SWIFT_SIL_GRAPH_OPERATION_INFO_H

lib/AST/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ if (SWIFT_FORCE_OPTIMIZED_TYPECHECKER)
33
set(EXTRA_AST_FLAGS "FORCE_BUILD_OPTIMIZED")
44
endif()
55

6+
if(SWIFT_ENABLE_TENSORFLOW)
7+
find_package(TensorFlow REQUIRED)
8+
include_directories(BEFORE "${TF_INCLUDE_DIR}")
9+
endif()
10+
611
add_swift_library(swiftAST STATIC
712
AccessScopeChecker.cpp
813
AccessRequests.cpp

lib/AST/GraphOperationInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,9 @@ std::pair<StringRef, GraphOperationInfo::ArgumentLowering>
149149
GraphOperationInfo::StructuredArgument::getArgumentNameAndLowering() const {
150150
return decodeArgumentName(Name);
151151
}
152+
153+
/// Determine whether the specified type is one of our well-known types, and
154+
/// if so, which one it is.
155+
TFValueKind tf::classifyTensorFlowValue(SILType ty) {
156+
return classifyTensorFlowValue(ty.getASTType());
157+
}

lib/AST/TensorFlow.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717

1818
#include "swift/AST/TensorFlow.h"
1919
#include "swift/AST/Decl.h"
20+
#include "swift/AST/Module.h"
2021
#include "swift/AST/Types.h"
22+
#ifdef SWIFT_ENABLE_TENSORFLOW
23+
#include "tensorflow/c/c_api.h"
24+
#endif
2125
using namespace swift;
2226
using namespace tf;
2327

@@ -36,6 +40,88 @@ bool tf::isTensorFlowDType(Type ty) {
3640
return !conformances.empty();
3741
}
3842

43+
static bool is64(Type ty) {
44+
return ty->getASTContext().LangOpts.Target.isArch64Bit();
45+
}
46+
47+
/// This function maps a Swift type (either a language type like Float or an
48+
/// LLVM Builtin type like Builtin.f32) into the TensorFlow TF_DataType value.
49+
///
50+
/// This returns 0 (which is an invalid tensorflow type ID) on error.
51+
///
52+
unsigned tf::convertSwiftTypeToTF(Type ty) {
53+
#ifdef SWIFT_ENABLE_TENSORFLOW
54+
// Handle wrappers like Float, which come up in TensorHandle<Float>
55+
if (auto *s = ty->getAs<StructType>()) {
56+
// Make sure the type is defined inside the Swift module.
57+
auto context = s->getDecl()->getDeclContext()->getParentModule();
58+
if (!context || context->getName().str() != "Swift")
59+
return 0;
60+
61+
return llvm::StringSwitch<unsigned>(s->getDecl()->getNameStr())
62+
.Case("Bool", TF_BOOL)
63+
.Case("Int8", TF_INT8)
64+
.Case("UInt8", TF_UINT8)
65+
.Case("Int16", TF_INT16)
66+
.Case("UInt16", TF_UINT16)
67+
.Case("Int32", TF_INT32)
68+
.Case("UInt32", TF_UINT32)
69+
.Case("Int64", TF_INT64)
70+
.Case("UInt64", TF_UINT64)
71+
.Case("Int8", TF_INT8)
72+
.Case("UInt8", TF_UINT8)
73+
.Case("BFloat16", TF_BFLOAT16)
74+
.Case("Float", TF_FLOAT)
75+
.Case("Double", TF_DOUBLE)
76+
.Case("Int", is64(s) ? TF_INT64 : TF_INT32)
77+
.Case("UInt", is64(s) ? TF_UINT64 : TF_UINT32)
78+
.Case("String", TF_STRING)
79+
.Default(0);
80+
}
81+
82+
// BuiltinIntegerType doesn't carry sign information, which TensorFlow needs,
83+
// so we can't rely on getting type information from the builtin types
84+
// themselves. For now we'll just use signed types.
85+
if (auto *BII = ty->getAs<BuiltinIntegerType>()) {
86+
if (BII->getWidth().isPointerWidth())
87+
return is64(ty) ? TF_INT64 : TF_INT32;
88+
89+
switch (BII->getFixedWidth()) {
90+
case 1:
91+
return TF_BOOL;
92+
case 8:
93+
return TF_INT8;
94+
case 16:
95+
return TF_INT16;
96+
case 32:
97+
return TF_INT32;
98+
case 64:
99+
return TF_INT64;
100+
}
101+
}
102+
103+
if (auto *BIF = ty->getAs<BuiltinFloatType>()) {
104+
switch (BIF->getFPKind()) {
105+
case BuiltinFloatType::IEEE16:
106+
return TF_HALF;
107+
case BuiltinFloatType::IEEE32:
108+
return TF_FLOAT;
109+
case BuiltinFloatType::IEEE64:
110+
return TF_DOUBLE;
111+
case BuiltinFloatType::IEEE80:
112+
case BuiltinFloatType::IEEE128:
113+
case BuiltinFloatType::PPC128:
114+
return 0;
115+
}
116+
}
117+
118+
if (auto *BRPT = ty->getAs<BuiltinRawPointerType>()) {
119+
return TF_STRING;
120+
}
121+
#endif
122+
return 0;
123+
}
124+
39125
/// If the specified type is the well-known TensorHandle<T> type, then return
40126
/// "T". If not, return a null type.
41127
Type tf::getTensorHandleElementType(Type ty) {

0 commit comments

Comments
 (0)