Skip to content

[TF] Moved most Tensor APIs to 'tensorflow/swift-apis'. #24161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
de8fdbd
Moved a couple of tensor initializers to swift-apis.
eaplatanios Apr 19, 2019
06b96c0
Moved the activation functions to swift-apis.
eaplatanios Apr 19, 2019
dcd46be
Minor edit.
eaplatanios Apr 19, 2019
7a957cc
Moved the log-softmax VJP to swift-apis.
eaplatanios Apr 19, 2019
75f7039
Moved some tensor initializers to swift-apis.
eaplatanios Apr 20, 2019
a897918
Moved some more stuff to swift-apis.
eaplatanios Apr 20, 2019
da184a8
Moved some more stuff to swift-apis.
eaplatanios Apr 20, 2019
cde450e
Moved some more stuff to swift-apis.
eaplatanios Apr 20, 2019
a878600
Removed the now-redundant 'Ops.swift' file.
eaplatanios Apr 20, 2019
e790f78
Moved the gradient helper methods to swift-apis.
eaplatanios Apr 20, 2019
93041e0
Moved the tensor tests to swift-apis.
eaplatanios Apr 20, 2019
d988084
Brought back the tensor APItests.
eaplatanios Apr 20, 2019
a238b6e
Added support for the TensorFlow op.
eaplatanios Apr 20, 2019
d759e24
Bug fix.
eaplatanios Apr 20, 2019
944d7f6
Bug fix.
eaplatanios Apr 20, 2019
12ec483
Updated the swift-apis dependency.
eaplatanios Apr 21, 2019
97d0ea8
Merged upstream changes.
eaplatanios Apr 21, 2019
993c972
Minor edit.
eaplatanios Apr 21, 2019
aa72c70
Added support for 'Dataset.repeated(count:)' since '#tfop' does not w…
eaplatanios Apr 21, 2019
4f8c2ca
Added support for prefetched datasets.
eaplatanios Apr 22, 2019
98e5704
Moved the dataset ops to swift-apis.
eaplatanios Apr 23, 2019
df0ec40
Removed the now-redundant 'ArrayOps.swift' file.
eaplatanios Apr 23, 2019
ce4dfd3
Changes to support the new swift-bindings.
eaplatanios Apr 23, 2019
701e31d
Updated the 'TensorArrayProtocol' and its automatic derivation implem…
eaplatanios Apr 23, 2019
a5fcdc2
Bug fixes.
eaplatanios Apr 24, 2019
22fed17
Minor edits.
eaplatanios Apr 24, 2019
2de2d2b
Addressed Dan's comments regarding the 'TensorArrayProtocol' derived …
eaplatanios Apr 24, 2019
93e335a
Minor bug fix.
eaplatanios Apr 24, 2019
6a79ac8
Minor edit.
eaplatanios Apr 24, 2019
5890133
Enhancements to 'TensorArrayProtocol'.
eaplatanios Apr 24, 2019
baed6e3
TensorFlow/TensorFlowCore refactoring.
eaplatanios Apr 25, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ Below is more information about TensorFlow-related build arguments.
* Default value: None.
* `tensorflow-swift-apis`: A path to the [tensorflow/swift-apis](https://github.com/tensorflow/swift-apis) deep learning library repository.
* Default value: `tensorflow-swift-apis` if the [tensorflow/swift-apis](https://github.com/tensorflow/swift-apis) repository is cloned. Otherwise, none.
* `tensorflow-swift-bindings`: A generated TensorFlow Swift bindings file (`RawOpsGenerated.swift`) obtained from [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings).
* Default value: `tensorflow-swift-bindings/RawOpsGenerated.swift` if the [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings) repository is cloned. Otherwise, none.
* `tensorflow-swift-bindings`: A path to the [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings) repository.
* Default value: `tensorflow-swift-bindings` if the [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings) repository is cloned. Otherwise, none.

### Build systems

Expand Down
3 changes: 2 additions & 1 deletion cmake/modules/SwiftSource.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ function(_compile_swift_files
# Also, disable it for DifferentiationUnittest because resilience changes
# the AD code # that gets generated (leading to additional leaks)
# (see: TF-328)
if(NOT "${SWIFTFILE_MODULE_NAME}" STREQUAL "TensorFlow" AND
if(NOT "${SWIFTFILE_MODULE_NAME}" STREQUAL "TensorFlowCore" AND
NOT "${SWIFTFILE_MODULE_NAME}" STREQUAL "TensorFlow" AND
NOT "${SWIFTFILE_MODULE_NAME}" STREQUAL "DifferentiationUnittest")
list(APPEND swift_flags "-Xfrontend" "-enable-resilience")
endif()
Expand Down
12 changes: 6 additions & 6 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,16 +481,16 @@ class ASTContext final {
CanType getAnyObjectType() const;

// SWIFT_ENABLE_TENSORFLOW
/// Retrieve the decl for TensorFlow.TensorHandle iff the TensorFlow module
/// has been imported. Otherwise, this returns null.
/// Retrieve the decl for TensorFlowCore.TensorHandle iff the TensorFlowCore
/// module has been imported. Otherwise, this returns null.
ClassDecl *getTensorHandleDecl() const;

/// Retrieve the decl for TensorFlow.TensorShape iff the TensorFlow module
/// has been imported. Otherwise, this returns null.
/// Retrieve the decl for TensorFlowCore.TensorShape iff the TensorFlowCore
/// module has been imported. Otherwise, this returns null.
StructDecl *getTensorShapeDecl() const;

/// Retrieve the decl for TensorFlow.TensorDataType iff the TensorFlow module
/// has been imported. Otherwise, this returns null.
/// Retrieve the decl for TensorFlowCore.TensorDataType iff the TensorFlowCore
/// module has been imported. Otherwise, this returns null.
StructDecl *getTensorDataTypeDecl() const;

/// Retrieve the type for Swift._AutoDiffTape.
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ IDENTIFIER(withArguments)
IDENTIFIER(withKeywordArguments)

// SWIFT_ENABLE_TENSORFLOW
IDENTIFIER(TensorFlow)
IDENTIFIER(TensorFlowCore)
// KeyPathIterable
IDENTIFIER(AllKeyPaths)
IDENTIFIER(allKeyPaths)
Expand Down
22 changes: 11 additions & 11 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,14 +821,14 @@ CanType ASTContext::getAnyObjectType() const {
}

// SWIFT_ENABLE_TENSORFLOW
/// Retrieve the decl for TensorFlow.TensorHandle iff the TensorFlow module has
/// been imported. Otherwise, this returns null.
/// Retrieve the decl for TensorFlowCore.TensorHandle iff the TensorFlow module
/// has been imported. Otherwise, this returns null.
ClassDecl *ASTContext::getTensorHandleDecl() const {
if (getImpl().TensorHandleDecl)
return getImpl().TensorHandleDecl;

// See if the TensorFlow module was imported. If not, return null.
auto tfModule = getLoadedModule(Id_TensorFlow);
auto tfModule = getLoadedModule(Id_TensorFlowCore);
if (!tfModule)
return nullptr;

Expand All @@ -842,14 +842,14 @@ ClassDecl *ASTContext::getTensorHandleDecl() const {
return nullptr;
}

/// Retrieve the decl for TensorFlow.TensorShape iff the TensorFlow module has
/// been imported. Otherwise, this returns null.
/// Retrieve the decl for TensorFlowCore.TensorShape iff the TensorFlow module
/// has been imported. Otherwise, this returns null.
StructDecl *ASTContext::getTensorShapeDecl() const {
if (getImpl().TensorShapeDecl)
return getImpl().TensorShapeDecl;

// See if the TensorFlow module was imported. If not, return null.
auto tfModule = getLoadedModule(Id_TensorFlow);
auto tfModule = getLoadedModule(Id_TensorFlowCore);
if (!tfModule)
return nullptr;

Expand All @@ -863,14 +863,14 @@ StructDecl *ASTContext::getTensorShapeDecl() const {
return nullptr;
}

/// Retrieve the decl for TensorFlow.TensorDataType iff the TensorFlow module has
/// been imported. Otherwise, this returns null.
/// Retrieve the decl for TensorFlowCore.TensorDataType iff the TensorFlow
/// module has been imported. Otherwise, this returns null.
StructDecl *ASTContext::getTensorDataTypeDecl() const {
if (getImpl().TensorDataTypeDecl)
return getImpl().TensorDataTypeDecl;

// See if the TensorFlow module was imported. If not, return null.
auto tfModule = getLoadedModule(Id_TensorFlow);
auto tfModule = getLoadedModule(Id_TensorFlowCore);
if (!tfModule)
return nullptr;

Expand Down Expand Up @@ -987,7 +987,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::TensorFlowDataTypeCompatible:
case KnownProtocolKind::TensorSendableReceivable:
case KnownProtocolKind::TensorProtocol:
M = getLoadedModule(Id_TensorFlow);
M = getLoadedModule(Id_TensorFlowCore);
break;
default:
M = getStdlibModule();
Expand Down Expand Up @@ -1886,7 +1886,7 @@ ASTContext::getModule(ArrayRef<std::pair<Identifier, SourceLoc>> ModulePath) {
(ModulePath[0].first == StdlibModuleName ||
ModulePath[0].first == Id_Foundation ||
// SWIFT_ENABLE_TENSORFLOW
ModulePath[0].first == Id_TensorFlow))
ModulePath[0].first == Id_TensorFlowCore))
recordKnownProtocols(M);
return M;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,7 @@ void IRGenSILFunction::visitGraphOperationInst(GraphOperationInst *i) {
tf::GraphOperationInfo opInfo(i);

// TODO: As an optimization, do this lookup once per CurSILFn
auto tfModule = astCtx.getLoadedModule(astCtx.Id_TensorFlow);
auto tfModule = astCtx.getLoadedModule(astCtx.Id_TensorFlowCore);
assert(tfModule && "could not find TensorFlow module");
auto inputTensorGroupProto =
astCtx.getProtocol(KnownProtocolKind::TensorArrayProtocol);
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/TFDeabstraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2603,7 +2603,7 @@ void TFDeabstractionPass::run() {
// If the TensorFlow module hasn't been imported by the program, don't do
// anything. This avoids impacting compile time for non-TensorFlow using
// Swift programs by doing extraneous analysis.
auto tfModule = ctx.getLoadedModule(ctx.Id_TensorFlow);
auto tfModule = ctx.getLoadedModule(ctx.Id_TensorFlowCore);
if (!tfModule)
return;

Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/TFPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4478,7 +4478,7 @@ void TFPartition::run() {
// If the TensorFlow module hasn't been imported by the program, don't do
// anything. This avoids impacting compile time for non-TensorFlow using
// Swift programs by doing extraneous analysis.
tfModule = ctx.getLoadedModule(ctx.Id_TensorFlow);
tfModule = ctx.getLoadedModule(ctx.Id_TensorFlowCore);
if (!tfModule)
return;

Expand Down
1 change: 0 additions & 1 deletion lib/SILOptimizer/PassManager/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ void swift::runSILTFPartitionPass(SILModule &Module) {
// Verify the module, if required.
if (Module.getOptions().VerifyAll)
Module.verify();

}

void swift::runSILOptimizationPassesWithFileSpecification(SILModule &M,
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2424,8 +2424,8 @@ namespace {
// The result type must conform to TensorGroup or be a tuple of types that
// conform to TensorGroup.

auto tfModule = ctx.getLoadedModule(ctx.Id_TensorFlow);
assert(tfModule && "could not find TensorFlow module");
auto tfModule = ctx.getLoadedModule(ctx.Id_TensorFlowCore);
assert(tfModule && "could not find TensorFlowCore module");
auto tensorGroupProto = ctx.getProtocol(KnownProtocolKind::TensorGroup);
assert(tensorGroupProto && "could not find TensorGroup protocol");

Expand Down
Loading