Skip to content

Commit 92e9221

Browse files
committed
Merge branch 'tensorflow' of github.com:apple/swift into tensorflow-merge
2 parents 563c52f + 79cdb96 commit 92e9221

File tree

4 files changed

+59
-46
lines changed

4 files changed

+59
-46
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -826,8 +826,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
826826
expectedJVPType);
827827
}
828828
silDiffAttr->setJVPName(jvpThunk->getName());
829-
// Unset JVP so that TBDGen triggers.
830-
diffAttr->setJVPFunction(nullptr);
831829
}
832830
// Thunk VJP method, if it is defined.
833831
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
@@ -845,8 +843,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
845843
expectedVJPType);
846844
}
847845
silDiffAttr->setVJPName(vjpThunk->getName());
848-
// Unset VJP so that TBDGen triggers.
849-
diffAttr->setVJPFunction(nullptr);
850846
}
851847
}
852848
}

lib/SILGen/SILGenType.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -88,36 +88,11 @@ SILGenModule::emitVTableMethod(ClassDecl *theClass,
8888
implFn = getDynamicThunk(derived, Types.getConstantInfo(derived).SILFnType);
8989
// SWIFT_ENABLE_TENSORFLOW
9090
} else if (auto *adafi = derived.autoDiffAssociatedFunctionIdentifier) {
91-
auto *decl = derived.getDecl();
92-
auto *DA = *llvm::find_if(
93-
decl->getAttrs().getAttributes<DifferentiableAttr>(),
94-
[&](const DifferentiableAttr *attr) {
95-
return attr->getParameterIndices() == adafi->getParameterIndices();
96-
});
97-
assert(DA && "Expected `@differentiable` attribute");
98-
// Get autodiff associated function declaration, if it exists.
99-
FuncDecl *assocDecl = nullptr;
100-
switch (adafi->getKind()) {
101-
case AutoDiffAssociatedFunctionKind::JVP:
102-
assocDecl = DA->getJVPFunction();
103-
break;
104-
case AutoDiffAssociatedFunctionKind::VJP:
105-
assocDecl = DA->getVJPFunction();
106-
break;
107-
}
108-
// If declaration exists, get corresponding SIL function.
109-
if (assocDecl) {
110-
SILDeclRef assocRef(assocDecl, SILDeclRef::Kind::Func);
111-
implFn = getFunction(assocRef, NotForDefinition);
112-
}
113-
// Otherwise, create an autodiff vtable entry thunk. The thunk contains an
114-
// `autodiff_function` instruction, which is later filled during
91+
// For JVP/VJP methods, create a vtable entry thunk. The thunk contains an
92+
// `autodiff_function` instruction, which is later filled during the
11593
// differentiation transform.
116-
// TODO(TF-524): Generalize canonical JVP/VJP thunk generation.
117-
else {
118-
implFn = getOrCreateAutoDiffClassMethodThunk(
119-
derived, Types.getConstantInfo(derived).SILFnType);
120-
}
94+
implFn = getOrCreateAutoDiffClassMethodThunk(
95+
derived, Types.getConstantInfo(derived).SILFnType);
12196
// SWIFT_ENABLE_TENSORFLOW END
12297
} else {
12398
implFn = getFunction(derived, NotForDefinition);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
public struct Wrapper : Differentiable, AdditiveArithmetic {
2+
public var x: Float
3+
public init(_ x: Float) {
4+
self.x = x
5+
}
6+
7+
public static func + (lhs: Wrapper, rhs: Wrapper) -> Wrapper {
8+
return Wrapper(lhs.x + rhs.x)
9+
}
10+
11+
@differentiating(+)
12+
public static func vjpAdd(lhs: Wrapper, rhs: Wrapper)
13+
-> (value: Wrapper, pullback: (Wrapper) -> (Wrapper, Wrapper)) {
14+
return (lhs + rhs, { v in (v, v) })
15+
}
16+
17+
public static func * (lhs: Wrapper, rhs: Wrapper) -> Wrapper {
18+
return Wrapper(lhs.x * rhs.x)
19+
}
20+
21+
@differentiating(*)
22+
public static func jvpMultiply(lhs: Wrapper, rhs: Wrapper)
23+
-> (value: Wrapper, differential: (Wrapper, Wrapper) -> Wrapper) {
24+
return (lhs * rhs, { dlhs, drhs in dlhs * rhs + lhs * drhs })
25+
}
26+
27+
@differentiating(*)
28+
public static func vjpMultiply(lhs: Wrapper, rhs: Wrapper)
29+
-> (value: Wrapper, pullback: (Wrapper) -> (Wrapper, Wrapper)) {
30+
return (lhs * rhs, { v in (v * rhs, v * lhs) })
31+
}
32+
}

test/AutoDiff/differentiable_attr_silgen_cross_module.swift

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1-
// RUN: %target-swift-frontend -emit-silgen -verify %s | %FileCheck %s -check-prefix=CHECK-SILGEN
2-
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -emit-module -primary-file %S/Inputs/differentiable_attr_silgen_other_module.swift -emit-module-path %t/differentiable_attr_silgen_other_module.swiftmodule
3+
// RUN: %target-swift-frontend -emit-silgen -verify -I %t -primary-file %s | %FileCheck %s -check-prefix=CHECK-SILGEN
4+
// RUN: %target-swift-frontend -emit-sil -verify -I %t -primary-file %s | %FileCheck %s -check-prefix=CHECK-SIL
35

4-
// After SILGen, SIL `[differentiable]` should have jvp/vjp names only if the AST `@differentiable` attribute does.
5-
// The differentiation pass is guaranteed to fill in jvp/vjp names.
6+
import differentiable_attr_silgen_other_module
67

7-
_ = gradient(at: Float(1)) { x in x + x * x }
8+
// After SILGen, a SIL `[differentiable]` attribute on a function from the
9+
// current module should have JVP/VJP names only if the AST `@differentiable`
10+
// attribute does.
811

9-
// CHECK-SILGEN-LABEL: // static Float.* infix(_:_:)
10-
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
11-
// CHECK-SIL-LABEL: // static Float.* infix(_:_:)
12-
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
12+
// For external functions, `[differentiable]` attribute JVP/VJP names should
13+
// always exist. The differentiation pass is guaranteed to fill in
14+
// `[differentiable]` attribute JVP/VJP names.
15+
16+
_ = pullback(at: Wrapper(1)) { x in x + x * x }
17+
18+
// CHECK-SILGEN-LABEL: // static Wrapper.* infix(_:_:)
19+
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
20+
// CHECK-SIL-LABEL: // static Wrapper.* infix(_:_:)
21+
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1moiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
22+
23+
// CHECK-SILGEN-LABEL: // static Wrapper.+ infix(_:_:)
24+
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
25+
// CHECK-SIL-LABEL: // static Wrapper.+ infix(_:_:)
26+
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
1327

14-
// CHECK-SILGEN-LABEL: // static Float.+ infix(_:_:)
15-
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
16-
// CHECK-SIL-LABEL: // static Float.+ infix(_:_:)
17-
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float

0 commit comments

Comments
 (0)