Skip to content

Commit 3cc1868

Browse files
authored
---
yaml --- r: 340965 b: refs/heads/rxwei-patch-1 c: 58e5175 h: refs/heads/master i: 340963: cc112c8
1 parent 0914a50 commit 3cc1868

File tree

7 files changed

+83
-14
lines changed

7 files changed

+83
-14
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: 2e8ef909a086639bc6dcbac46325b4211845ae29
1018+
refs/heads/rxwei-patch-1: 58e5175cf6af8e45d73d832d3c6a01840b40f07c
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/lib/SIL/SILFunctionBuilder.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,42 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
8080
!constant.isStoredPropertyInitializer() &&
8181
!constant.isThunk()) {
8282
for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) {
83-
std::string jvpName, vjpName;
84-
// Get JVP/VJP names.
85-
if (auto *jvpFn = A->getJVPFunction())
86-
jvpName = SILDeclRef(jvpFn).mangle();
87-
if (auto *vjpFn = A->getVJPFunction())
88-
vjpName = SILDeclRef(vjpFn).mangle();
8983
// Get lowered argument indices.
90-
auto paramIndices = A->getParameterIndices();
84+
auto *paramIndices = A->getParameterIndices();
9185
// NOTE: If `A->getParameterIndices()` is `nullptr`, continue. This is a
9286
// necessary hack regarding deserialization.
9387
if (!paramIndices)
9488
continue;
95-
auto loweredParamIndices = paramIndices->getLowered(
89+
auto *loweredParamIndices = paramIndices->getLowered(
9690
F->getASTContext(),
9791
decl->getInterfaceType()->castTo<AnyFunctionType>());
9892
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
99-
auto silDiffAttr = SILDifferentiableAttr::create(
93+
// Get JVP/VJP names.
94+
std::string jvpName, vjpName;
95+
// If a method-self-reordering thunk is generated for the original
96+
// function, use mangled JVP/VJP symbols.
97+
auto *AFD = constant.getAbstractFunctionDecl();
98+
auto selfParamIndex =
99+
F->getLoweredFunctionType()->getNumParameters() - 1;
100+
if (AFD && AFD->isInstanceMember() &&
101+
F->getLoweredFunctionType()->hasSelfParam() &&
102+
indices.isWrtParameter(selfParamIndex) &&
103+
indices.parameters->getNumIndices() > 1) {
104+
auto &ctx = F->getASTContext();
105+
if (A->getJVPFunction())
106+
jvpName = ctx.getIdentifier(
107+
"AD__" + constant.mangle() + "__jvp_" + indices.mangle()).str();
108+
if (A->getVJPFunction()) {
109+
vjpName = ctx.getIdentifier(
110+
"AD__" + constant.mangle() + "__vjp_" + indices.mangle()).str();
111+
}
112+
} else {
113+
if (auto *jvpFn = A->getJVPFunction())
114+
jvpName = SILDeclRef(jvpFn).mangle();
115+
if (auto *vjpFn = A->getVJPFunction())
116+
vjpName = SILDeclRef(vjpFn).mangle();
117+
}
118+
auto *silDiffAttr = SILDifferentiableAttr::create(
100119
M, indices, A->getRequirements(), M.allocateCopy(jvpName),
101120
M.allocateCopy(vjpName));
102121
#ifndef NDEBUG

branches/rxwei-patch-1/lib/SILGen/SILGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
781781
LLVM_DEBUG(llvm::dbgs() << "lowered sil:\n";
782782
F->print(llvm::dbgs()));
783783

784+
// SWIFT_ENABLE_TENSORFLOW
784785
// Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
785786
if (constant.hasDecl()) {
786787
auto *AFD = constant.getAbstractFunctionDecl();

branches/rxwei-patch-1/lib/TBDGen/TBDGen.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,20 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
219219
// `@differentiable` attribute.
220220
auto diffAttrs = AFD->getAttrs().getAttributes<DifferentiableAttr>();
221221
for (auto *DA : diffAttrs) {
222+
// If a method-self-reordering thunk is generated for the original function,
223+
// emit symbol.
224+
auto isSelfReorderedMethod = AFD && AFD->isInstanceMember() &&
225+
AFD->hasImplicitSelfDecl() &&
226+
DA->getParameterIndices()->parameters.count() > 1;
222227
// FIXME: When we get rid of `vjp:` and `jvp:` arguments in `@differentiable`,
223228
// we will no longer need to see whether they are specified.
224-
if (!DA->getJVPFunction()) {
229+
if (!DA->getJVPFunction() || isSelfReorderedMethod) {
225230
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
226231
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
227232
DA->getParameterIndices(), AFD->getASTContext());
228233
addSymbol(SILDeclRef(AFD).asAutoDiffAssociatedFunction(id));
229234
}
230-
if (!DA->getVJPFunction()) {
235+
if (!DA->getVJPFunction() || isSelfReorderedMethod) {
231236
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
232237
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,
233238
DA->getParameterIndices(), AFD->getASTContext());
@@ -304,15 +309,21 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) {
304309
// with a `@differentiable` attribute.
305310
auto diffAttrs = ASD->getAttrs().getAttributes<DifferentiableAttr>();
306311
for (auto *DA : diffAttrs) {
312+
// If a method-self-reordering thunk is generated for the original function,
313+
// emit symbol.
314+
auto isSelfReorderedMethod = ASD->getGetter() &&
315+
ASD->getGetter()->isInstanceMember() &&
316+
ASD->getGetter()->hasImplicitSelfDecl() &&
317+
DA->getParameterIndices()->parameters.count() > 1;
307318
// FIXME: When we get rid of `vjp:` and `jvp:` arguments in `@differentiable`,
308319
// we will no longer need to see whether they are specified.
309-
if (!DA->getJVPFunction()) {
320+
if (!DA->getJVPFunction() || isSelfReorderedMethod) {
310321
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
311322
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
312323
DA->getParameterIndices(), ASD->getASTContext());
313324
addSymbol(SILDeclRef(ASD->getGetter()).asAutoDiffAssociatedFunction(id));
314325
}
315-
if (!DA->getVJPFunction()) {
326+
if (!DA->getVJPFunction() || isSelfReorderedMethod) {
316327
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
317328
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,
318329
DA->getParameterIndices(), ASD->getASTContext());
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
struct TF_619: Differentiable {
2+
var p: Float = 1
3+
4+
@differentiable(vjp: vjpFoo)
5+
func foo(_ x: Float) -> Float {
6+
return p * x
7+
}
8+
9+
func vjpFoo(_ x: Float) -> (Float, (Float) -> (TangentVector, Float)) {
10+
return (x, { v in (TangentVector(p: v * x), v * self.p) })
11+
}
12+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift %S/../Inputs/method_self_reordering_thunk_other_module.swift %s -o %t/a.out
3+
// RUN: %target-codesign %t/a.out
4+
// RUN: %target-run %t/a.out
5+
6+
// REQUIRES: executable_test
7+
8+
import StdlibUnittest
9+
10+
var MethodSelfReorderingThunkTests = TestSuite("MethodSelfReorderingThunks")
11+
12+
// Test TF-619: cross-module import of `@differentiable` methods with
13+
// self-ordering thunks.
14+
MethodSelfReorderingThunkTests.test("CrossModuleMethodSelfReorderingThunk") {
15+
expectEqual(1, gradient(at: 0) { x in TF_619().foo(x) })
16+
}
17+
18+
runAllTests()

branches/rxwei-patch-1/test/AutoDiff/tbdgen.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ public extension Float {
4242
}
4343

4444
// This should generate public symbols for both JVP and VJP.
45+
// Tests self-reordering-method thunking.
46+
@differentiable
47+
func method(x: Float, y: Float) -> Float {
48+
return x
49+
}
50+
51+
// This should generate public symbols for both JVP and VJP.
52+
// Tests self-reordering-method thunking.
4553
@differentiable
4654
subscript(x: Float) -> Float {
4755
return x

0 commit comments

Comments
 (0)