Skip to content

[AutoDiff] Enable cross-file derivative registration by default. #31249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,6 @@ namespace swift {
bool EnableExperimentalDifferentiableProgramming = true;
// SWIFT_ENABLE_TENSORFLOW END

// SWIFT_ENABLE_TENSORFLOW
/// Whether to enable cross-file derivative registration.
bool EnableExperimentalCrossFileDerivativeRegistration = false;
// SWIFT_ENABLE_TENSORFLOW END

/// Whether to enable forward mode differentiation.
bool EnableExperimentalForwardModeDifferentiation = false;

Expand Down
8 changes: 0 additions & 8 deletions include/swift/Option/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,6 @@ def disable_bridging_pch : Flag<["-"], "disable-bridging-pch">,

// Experimental feature options

// SWIFT_ENABLE_TENSORFLOW
// TODO(TF-1097): Remove this flag and always enable cross-file derivative registration.
def enable_experimental_cross_file_derivative_registration :
Flag<["-"], "enable-experimental-cross-file-derivative-registration">,
Flags<[FrontendOption, ModuleInterfaceOption]>,
HelpText<"Enable experimental cross-file derivative registration">;
// SWIFT_ENABLE_TENSORFLOW END

// Note: this flag will be removed when JVP/differential generation in the
// differentiation transform is robust.
def enable_experimental_forward_mode_differentiation :
Expand Down
6 changes: 0 additions & 6 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,6 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
if (Args.hasArg(OPT_enable_experimental_additive_arithmetic_derivation))
Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true;

// SWIFT_ENABLE_TENSORFLOW
// TODO(TF-1097): Remove this flag.
Opts.EnableExperimentalCrossFileDerivativeRegistration |=
Args.hasArg(OPT_enable_experimental_cross_file_derivative_registration);
// SWIFT_ENABLE_TENSORFLOW END

Opts.EnableExperimentalForwardModeDifferentiation |=
Args.hasArg(OPT_enable_experimental_forward_mode_differentiation);

Expand Down
11 changes: 0 additions & 11 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4441,17 +4441,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return true;
}

// SWIFT_ENABLE_TENSORFLOW
// Reject different-file derivative registration.
// TODO(TF-1021): Lift same-file derivative registration restriction.
if (!Ctx.LangOpts.EnableExperimentalCrossFileDerivativeRegistration &&
originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
diags.diagnose(attr->getLocation(),
diag::derivative_attr_not_in_same_file_as_original);
return true;
}
// SWIFT_ENABLE_TENSORFLOW END

// Reject duplicate `@derivative` attributes.
auto &derivativeAttrs = Ctx.DerivativeAttrs[std::make_tuple(
originalAFD, resolvedDiffParamIndices, kind)];
Expand Down
5 changes: 0 additions & 5 deletions stdlib/cmake/modules/AddSwiftStdlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1733,11 +1733,6 @@ function(add_swift_target_library name)
endif()
endif()

# SWIFT_ENABLE_TENSORFLOW
list(APPEND swiftlib_swift_compile_flags_all
-Xfrontend -enable-experimental-cross-file-derivative-registration)
# SWIFT_ENABLE_TENSORFLOW END

# Collect architecture agnostic SDK linker flags
set(swiftlib_link_flags_all ${SWIFTLIB_LINK_FLAGS})
if(${sdk} STREQUAL IOS_SIMULATOR AND ${name} STREQUAL swiftMediaPlayer)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/a.swift -emit-module-path %t/a.swiftmodule
// SWIFT_ENABLE_TENSORFLOW
// Use `-enable-experimental-cross-file-derivative-registration` flag. To be removed soon.
// RUN: %target-swift-frontend -enable-experimental-cross-file-derivative-registration -emit-module -primary-file %S/Inputs/b.swift -emit-module-path %t/b.swiftmodule -I %t
// SWIFT_ENABLE_TENSORFLOW END
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/b.swift -emit-module-path %t/b.swiftmodule -I %t
// "-verify-ignore-unknown" is for "<unknown>:0: note: 'init()' declared here"
// RUN: %target-swift-frontend-typecheck -verify -verify-ignore-unknown -I %t %s

Expand Down
1 change: 0 additions & 1 deletion test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,6 @@ extension InoutParameters {
// Test cross-file derivative registration.

extension FloatingPoint where Self: Differentiable {
// expected-error @+1 {{derivative not in the same file as the original function}}
@derivative(of: rounded)
func vjpRounded() -> (
value: Self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// cross_module_derivative_attr_e2e.swift.

// RUN: %empty-directory(%t)
// RUN: %target-build-swift -Xfrontend -enable-experimental-cross-file-derivative-registration -parse-as-library -emit-module -module-name MultiFileModule -emit-module-path %t/MultiFileModule.swiftmodule -emit-library -o %t/%target-library-name(MultiFileModule) %S/Inputs/always_emit_into_client/MultiFileModule/file1.swift %S/Inputs/always_emit_into_client/MultiFileModule/file2.swift
// RUN: %target-build-swift -parse-as-library -emit-module -module-name MultiFileModule -emit-module-path %t/MultiFileModule.swiftmodule -emit-library -o %t/%target-library-name(MultiFileModule) %S/Inputs/always_emit_into_client/MultiFileModule/file1.swift %S/Inputs/always_emit_into_client/MultiFileModule/file2.swift
// RUN: not %target-build-swift -I%t -L%t %s -o %t/a.out -lm -lMultiFileModule 2>&1 | %FileCheck %s

import StdlibUnittest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// cross_module_derivative_attr_e2e.swift.

// RUN: %empty-directory(%t)
// RUN: %target-build-swift -Xfrontend -enable-experimental-cross-file-derivative-registration -parse-as-library -emit-module -module-name SingleFileModule -emit-module-path %t/SingleFileModule.swiftmodule -emit-library -o %t/%target-library-name(SingleFileModule) %S/Inputs/always_emit_into_client/SingleFileModule/file.swift
// RUN: %target-build-swift -Xfrontend -parse-as-library -emit-module -module-name SingleFileModule -emit-module-path %t/SingleFileModule.swiftmodule -emit-library -o %t/%target-library-name(SingleFileModule) %S/Inputs/always_emit_into_client/SingleFileModule/file.swift
// RUN: not %target-build-swift -I%t -L%t %s -o %t/a.out -lm -lSingleFileModule 2>&1 | %FileCheck %s

import StdlibUnittest
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %empty-directory(%t)
// RUN: %target-build-swift -working-directory %t -I%t -parse-as-library -emit-module -module-name module1 -emit-module-path %t/module1.swiftmodule -emit-library -static %S/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift %S/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift -Xfrontend -enable-experimental-cross-file-derivative-registration -Xfrontend -validate-tbd-against-ir=none
// RUN: %target-build-swift -I%t -L%t %S/Inputs/cross_module_derivative_attr_e2e/main/main.swift -o %t/a.out -lm -lmodule1 -Xfrontend -enable-experimental-cross-file-derivative-registration -Xfrontend -validate-tbd-against-ir=none
// RUN: %target-build-swift -working-directory %t -I%t -parse-as-library -emit-module -module-name module1 -emit-module-path %t/module1.swiftmodule -emit-library -static %S/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift %S/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift -Xfrontend -validate-tbd-against-ir=none
// RUN: %target-build-swift -I%t -L%t %S/Inputs/cross_module_derivative_attr_e2e/main/main.swift -o %t/a.out -lm -lmodule1 -Xfrontend -validate-tbd-against-ir=none
// RUN: %target-run %t/a.out
// REQUIRES: executable_test
4 changes: 1 addition & 3 deletions test/AutoDiff/downstream/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,8 @@ func vjpFooExtraGenericRequirements<T : FloatingPoint & Differentiable & BinaryI
return (x, { $0 })
}

// Test cross-file derivative registration. Currently unsupported.
// TODO(TF-1021): Lift this restriction.
// Test cross-file derivative registration.
extension FloatingPoint where Self: Differentiable {
// expected-error @+1 {{derivative not in the same file as the original function}}
@derivative(of: rounded)
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
fatalError()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %empty-directory(%t)
// RUN: %clang -c %S/Inputs/Foreign.c -fmodules -o %t/CForeign.o
// RUN: %target-swift-emit-silgen -enable-experimental-cross-file-derivative-registration -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SILGEN --check-prefix=CHECK
// RUN: %target-swift-emit-sil -enable-experimental-cross-file-derivative-registration -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SIL --check-prefix=CHECK
// RUN: %target-build-swift -Xfrontend -enable-experimental-cross-file-derivative-registration -I %S/Inputs -I %t %s %t/CForeign.o
// RUN: %target-swift-emit-silgen -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SILGEN --check-prefix=CHECK
// RUN: %target-swift-emit-sil -I %S/Inputs -I %t %s | %FileCheck %s --check-prefix=CHECK-SIL --check-prefix=CHECK
// RUN: %target-build-swift -I %S/Inputs -I %t %s %t/CForeign.o

import CForeign

Expand Down