Skip to content

Commit b4a8aac

Browse files
author
Marc Rasi
committed
[AutoDiff] flag for cross-file derivative registration
1 parent 7d3ae09 commit b4a8aac

File tree

5 files changed

+31
-1
lines changed

5 files changed

+31
-1
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,16 @@
3535
#include "swift/Sema/IDETypeChecking.h"
3636
#include "clang/Basic/CharInfo.h"
3737
#include "llvm/Support/Debug.h"
38+
// SWIFT_ENABLE_TENSORFLOW
39+
#include "llvm/Support/Options.h"
3840

3941
using namespace swift;
4042

43+
// SWIFT_ENABLE_TENSORFLOW
44+
static llvm::cl::opt<bool> EnableExperimentalCrossFileDerivativeRegistration(
45+
"enable-experimental-cross-file-derivative-registration",
46+
llvm::cl::init(false));
47+
4148
namespace {
4249
/// This emits a diagnostic with a fixit to remove the attribute.
4350
template<typename ...ArgTypes>
@@ -3646,7 +3653,8 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36463653

36473654
// Reject different-file derivative registration.
36483655
// TODO(TF-1021): Lift this restriction.
3649-
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
3656+
if (!EnableExperimentalCrossFileDerivativeRegistration &&
3657+
originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
36503658
diagnoseAndRemoveAttr(attr,
36513659
diag::derivative_attr_not_in_same_file_as_original);
36523660
return;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import StdlibUnittest
2+
3+
import module1
4+
5+
var Tests = TestSuite("CrossModuleDerivativeAttr")
6+
7+
Tests.test("CrossFile") {
8+
let grad = gradient(at: 0, in: fCrossFile)
9+
expectEqual(10, grad)
10+
}
11+
12+
runAllTests()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
public func fCrossFile(_ x: Float) -> Float { x }
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@derivative(of: fCrossFile)
2+
public func vjpCrossFile(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
3+
(x, { 10 * $0 })
4+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -I%t -c -parse-as-library -emit-module -module-name module1 -emit-module-path %t/module1.swiftmodule -o %t/module1.o %S/Inputs/cross_module_derivative_attr_e2e/module1/module1.swift %S/Inputs/cross_module_derivative_attr_e2e/module1/module1_other_file.swift -Xllvm -enable-experimental-cross-file-derivative-registration -validate-tbd-against-ir=none
3+
// RUN: %target-build-swift -I%t %S/Inputs/cross_module_derivative_attr_e2e/main/main.swift %t/module1.o -o %t/a.out -lm -Xllvm -enable-experimental-cross-file-derivative-registration -Xfrontend -validate-tbd-against-ir=none
4+
// RUN: %target-run %t/a.out
5+
// REQUIRES: executable_test

0 commit comments

Comments
 (0)