Skip to content

Commit 49feb45

Browse files
author
marcrasi
authored
1 parent d214a91 commit 49feb45

File tree

2 files changed

+103
-29
lines changed

2 files changed

+103
-29
lines changed

include/swift/SIL/SILWitnessVisitor.h

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,16 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
120120

121121
void visitAbstractStorageDecl(AbstractStorageDecl *sd) {
122122
sd->visitOpaqueAccessors([&](AccessorDecl *accessor) {
123-
if (SILDeclRef::requiresNewWitnessTableEntry(accessor))
124-
asDerived().addMethod(SILDeclRef(accessor, SILDeclRef::Kind::Func));
123+
// SWIFT_ENABLE_TENSORFLOW
124+
addMethodAndAutoDiffAssociatedMethodsIfRequired(accessor,
125+
SILDeclRef::Kind::Func);
125126
});
126127
}
127128

128129
void visitConstructorDecl(ConstructorDecl *cd) {
129-
if (SILDeclRef::requiresNewWitnessTableEntry(cd))
130-
asDerived().addMethod(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
130+
// SWIFT_ENABLE_TENSORFLOW
131+
addMethodAndAutoDiffAssociatedMethodsIfRequired(
132+
cd, SILDeclRef::Kind::Allocator);
131133
}
132134

133135
void visitAccessorDecl(AccessorDecl *func) {
@@ -137,22 +139,8 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
137139
void visitFuncDecl(FuncDecl *func) {
138140
assert(!isa<AccessorDecl>(func));
139141
// SWIFT_ENABLE_TENSORFLOW
140-
if (!SILDeclRef::requiresNewWitnessTableEntry(func))
141-
return;
142-
143-
auto funcDeclRef = SILDeclRef(func, SILDeclRef::Kind::Func);
144-
asDerived().addMethod(funcDeclRef);
145-
146-
if (auto *DA = func->getAttrs().getAttribute<DifferentiableAttr>()) {
147-
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
148-
AutoDiffAssociatedFunctionIdentifier::get(
149-
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
150-
DA->getParameterIndices(), func->getASTContext())));
151-
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
152-
AutoDiffAssociatedFunctionIdentifier::get(
153-
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,
154-
DA->getParameterIndices(), func->getASTContext())));
155-
}
142+
addMethodAndAutoDiffAssociatedMethodsIfRequired(func,
143+
SILDeclRef::Kind::Func);
156144
}
157145

158146
void visitMissingMemberDecl(MissingMemberDecl *placeholder) {
@@ -179,6 +167,28 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
179167
void visitPoundDiagnosticDecl(PoundDiagnosticDecl *pdd) {
180168
// We don't care about diagnostics at this stage.
181169
}
170+
171+
// SWIFT_ENABLE_TENSORFLOW
172+
private:
173+
void addMethodAndAutoDiffAssociatedMethodsIfRequired(
174+
AbstractFunctionDecl *func, SILDeclRef::Kind kind) {
175+
if (!SILDeclRef::requiresNewWitnessTableEntry(func))
176+
return;
177+
178+
auto funcDeclRef = SILDeclRef(func, kind);
179+
asDerived().addMethod(funcDeclRef);
180+
181+
if (auto *DA = func->getAttrs().getAttribute<DifferentiableAttr>()) {
182+
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
183+
AutoDiffAssociatedFunctionIdentifier::get(
184+
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
185+
DA->getParameterIndices(), func->getASTContext())));
186+
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
187+
AutoDiffAssociatedFunctionIdentifier::get(
188+
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,
189+
DA->getParameterIndices(), func->getASTContext())));
190+
}
191+
}
182192
};
183193

184194
} // end namespace swift

test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ import StdlibUnittest
44

55
var ProtocolRequirementAutodiffTests = TestSuite("ProtocolRequirementAutodiff")
66

7+
// MARK: - Func requirements.
8+
79
protocol DiffReq : Differentiable {
810
@differentiable(wrt: (self, x))
911
func f(_ x: Float) -> Float
1012
}
1113

1214
extension DiffReq where TangentVector : AdditiveArithmetic {
15+
@inline(never) // Prevent specialization, to test all witness code.
1316
func gradF(at x: Float) -> (Self.TangentVector, Float) {
1417
return (valueWithPullback(at: x) { s, x in s.f(x) }).1(1)
1518
}
@@ -53,7 +56,76 @@ extension Quadratic : VectorProtocol {
5356
}
5457
}
5558

56-
// Test witness method SIL type computation.
59+
ProtocolRequirementAutodiffTests.test("func") {
60+
expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
61+
expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
62+
Quadratic(11, 12, 13).gradF(at: 1))
63+
expectEqual((Quadratic(4, 2, 1), 2 * 11 * 2 + 12),
64+
Quadratic(11, 12, 13).gradF(at: 2))
65+
}
66+
67+
// MARK: Constructor, accessor, and subscript requirements.
68+
69+
protocol FunctionsOfX: Differentiable {
70+
@differentiable
71+
init(x: Float)
72+
73+
@differentiable
74+
var x: Float { get }
75+
76+
@differentiable
77+
var y: Float { get }
78+
79+
@differentiable
80+
var z: Float { get }
81+
82+
@differentiable
83+
subscript() -> Float { get }
84+
}
85+
86+
struct TestFunctionsOfX: FunctionsOfX {
87+
@differentiable
88+
init(x: Float) {
89+
self.x = x
90+
self.y = x * x
91+
}
92+
93+
/// x = x
94+
var x: Float
95+
96+
/// y = x * x
97+
var y: Float
98+
99+
/// z = x * x + x
100+
var z: Float {
101+
return y + x
102+
}
103+
104+
@differentiable
105+
subscript() -> Float {
106+
return z
107+
}
108+
}
109+
110+
@inline(never) // Prevent specialization, to test all witness code.
111+
func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
112+
-> (Float, Float, Float, Float)
113+
{
114+
let dxdx = gradient(at: x) { x in F(x: x).x }
115+
let dydx = gradient(at: x) { x in F(x: x).y }
116+
let dzdx = gradient(at: x) { x in F(x: x).z }
117+
let dsubscriptdx = gradient(at: x) { x in F(x: x)[] }
118+
return (dxdx, dydx, dzdx, dsubscriptdx)
119+
}
120+
121+
ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {
122+
expectEqual(
123+
derivatives(at: 2.0, in: TestFunctionsOfX.self),
124+
(1.0, 4.0, 5.0, 5.0))
125+
}
126+
127+
// MARK: - Test witness method SIL type computation.
128+
57129
protocol P : Differentiable {
58130
@differentiable(wrt: (x, y))
59131
func foo(_ x: Float, _ y: Double) -> Float
@@ -65,12 +137,4 @@ struct S : P {
65137
}
66138
}
67139

68-
ProtocolRequirementAutodiffTests.test("Trivial") {
69-
expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
70-
expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
71-
Quadratic(11, 12, 13).gradF(at: 1))
72-
expectEqual((Quadratic(4, 2, 1), 2 * 11 * 2 + 12),
73-
Quadratic(11, 12, 13).gradF(at: 2))
74-
}
75-
76140
runAllTests()

0 commit comments

Comments
 (0)