Skip to content

Commit 2d2a5de

Browse files
committed
[AutoDiff upstream] Minor cleanup for @differentiable attribute.
Add TODO comments referencing JIRA issues. Add disabled `@differentiable` attribute serialization test. TF-836 tracks enabling the test. Blocked by TF-828: `@differentiable` attribute type-checking.
1 parent 23edd22 commit 2d2a5de

File tree

4 files changed

+143
-1
lines changed

4 files changed

+143
-1
lines changed

lib/AST/Attr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,13 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
857857
Printer << ")";
858858
break;
859859

860+
case DAK_Differentiable: {
861+
Printer.printAttrName("@differentiable");
862+
auto *attr = cast<DifferentiableAttr>(this);
863+
printDifferentiableAttrArguments(attr, Printer, Options, D);
864+
break;
865+
}
866+
860867
case DAK_Count:
861868
llvm_unreachable("exceed declaration attribute kinds");
862869

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
112112
IGNORED_ATTR(ProjectedValueProperty)
113113
IGNORED_ATTR(ReferenceOwnership)
114114

115-
// TODO: Changes are yet to be upstreamed from apple/tensorflow branch.
115+
// TODO(TF-828): Upstream `@differentiable` attribute type-checking from
116+
// tensorflow branch.
116117
IGNORED_ATTR(Differentiable)
117118
#undef IGNORED_ATTR
118119

lib/Serialization/Serialization.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2291,6 +2291,11 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
22912291
vjpRef = S.addDeclRef(vjpFunction);
22922292

22932293
auto paramIndices = attr->getParameterIndices();
2294+
// TODO(TF-837): Implement `@differentiable` attribute serialization.
2295+
// Blocked by TF-828: `@differentiable` attribute type-checking, which
2296+
// resolves parameter indices (`IndexSubset *`).
2297+
if (!paramIndices)
2298+
return;
22942299
assert(paramIndices && "Checked parameter indices must be resolved");
22952300
SmallVector<bool, 4> indices;
22962301
for (unsigned i : range(paramIndices->getCapacity()))
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
3+
// RUN: %empty-directory(%t)
4+
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
5+
// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
6+
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
7+
// REQUIRES: differentiable_programming
8+
9+
// TODO(TF-836): Enable this test.
10+
// Blocked by TF-828: `@differentiating` attribute type-checking.
11+
// XFAIL: *
12+
13+
// BCANALYZER-NOT: UnknownCode
14+
15+
import _Differentiation
16+
17+
// CHECK: @differentiable(wrt: x, jvp: jvpSimple, vjp: vjpSimple)
18+
// CHECK-NEXT: func simple(x: Float) -> Float
19+
@differentiable(jvp: jvpSimple, vjp: vjpSimple)
20+
func simple(x: Float) -> Float {
21+
return x
22+
}
23+
24+
// CHECK: @differentiable(linear, wrt: x)
25+
// CHECK-NEXT: func simple2(x: Float) -> Float
26+
@differentiable(linear)
27+
func simple2(x: Float) -> Float {
28+
return x
29+
}
30+
31+
// CHECK: @differentiable(linear, wrt: x)
32+
// CHECK-NEXT: func simple4(x: Float) -> Float
33+
@differentiable(linear, wrt: x)
34+
func simple4(x: Float) -> Float {
35+
return x
36+
}
37+
38+
func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
39+
return (x, { v in v })
40+
}
41+
42+
func vjpSimple(x: Float) -> (Float, (Float) -> Float) {
43+
return (x, { v in v })
44+
}
45+
46+
// CHECK: @differentiable(wrt: x)
47+
// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
48+
@differentiable(wrt: x)
49+
func testWrtClause(x: Float, y: Float) -> Float {
50+
return x + y
51+
}
52+
53+
struct InstanceMethod : Differentiable {
54+
// CHECK: @differentiable(wrt: (self, y))
55+
// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
56+
@differentiable(wrt: (self, y))
57+
func testWrtClause(x: Float, y: Float) -> Float {
58+
return x + y
59+
}
60+
61+
struct TangentVector: Differentiable, AdditiveArithmetic {
62+
typealias TangentVector = Self
63+
static func ==(_: Self, _: Self) -> Bool { fatalError() }
64+
static var zero: Self { fatalError() }
65+
static func +(_: Self, _: Self) -> Self { fatalError() }
66+
static func -(_: Self, _: Self) -> Self { fatalError() }
67+
}
68+
mutating func move(along direction: TangentVector) {}
69+
}
70+
71+
// CHECK: @differentiable(wrt: x where T : Differentiable)
72+
// CHECK-NEXT: func testOnlyWhereClause<T>(x: T) -> T where T : Numeric
73+
@differentiable(where T : Differentiable)
74+
func testOnlyWhereClause<T : Numeric>(x: T) -> T {
75+
return x
76+
}
77+
78+
// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClause where T : Differentiable)
79+
// CHECK-NEXT: func testWhereClause<T>(x: T) -> T where T : Numeric
80+
@differentiable(vjp: vjpTestWhereClause where T : Differentiable)
81+
func testWhereClause<T : Numeric>(x: T) -> T {
82+
return x
83+
}
84+
func vjpTestWhereClause<T>(x: T) -> (T, (T.TangentVector) -> T.TangentVector)
85+
where T : Numeric, T : Differentiable
86+
{
87+
return (x, { v in v })
88+
}
89+
90+
protocol P {}
91+
extension P {
92+
// CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethod where Self : Differentiable)
93+
// CHECK-NEXT: func testWhereClauseMethod() -> Self
94+
@differentiable(wrt: self, vjp: vjpTestWhereClauseMethod where Self : Differentiable)
95+
func testWhereClauseMethod() -> Self {
96+
return self
97+
}
98+
}
99+
extension P where Self : Differentiable {
100+
func vjpTestWhereClauseMethod() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
101+
return (self, { v in v })
102+
}
103+
}
104+
105+
// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector)
106+
// CHECK-NEXT: func testWhereClauseMethodTypeConstraint<T>(x: T) -> T where T : Numeric
107+
@differentiable(vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector)
108+
func testWhereClauseMethodTypeConstraint<T : Numeric>(x: T) -> T {
109+
return x
110+
}
111+
func vjpTestWhereClauseMethodTypeConstraint<T>(x: T) -> (T, (T) -> T)
112+
where T : Numeric, T : Differentiable, T == T.TangentVector
113+
{
114+
return (x, { v in v })
115+
}
116+
117+
extension P {
118+
// CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self : Differentiable, Self == Self.TangentVector)
119+
// CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self
120+
@differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self.TangentVector == Self, Self : Differentiable)
121+
func testWhereClauseMethodTypeConstraint() -> Self {
122+
return self
123+
}
124+
}
125+
extension P where Self : Differentiable, Self == Self.TangentVector {
126+
func vjpTestWhereClauseMethodTypeConstraint() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
127+
return (self, { v in v })
128+
}
129+
}

0 commit comments

Comments
 (0)