Skip to content

Commit a439d6f

Browse files
author
Marc Rasi
committed
Merge branch 'tensorflow' into tf-remove-differentiable-jvp-vjp
2 parents ab83c95 + a3c614f commit a439d6f

File tree

3 files changed

+14
-15
lines changed

3 files changed

+14
-15
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3548,11 +3548,10 @@ resolveDifferentiableAttrOriginalFunction(DifferentiableAttr *attr) {
35483548
}
35493549
}
35503550
// Non-`get` accessors are not yet supported: `set`, `read`, and `modify`.
3551-
// TODO(TF-129): Enable `set` when differentiation supports inout parameters.
35523551
// TODO(TF-1080): Enable `read` and `modify` when differentiation supports
35533552
// coroutines.
35543553
if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
3555-
if (!accessor->isGetter())
3554+
if (!accessor->isGetter() && !accessor->isSetter())
35563555
original = nullptr;
35573556
// Diagnose if original `AbstractFunctionDecl` could not be resolved.
35583557
if (!original) {

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ struct DifferentiableInstanceMethod: Differentiable {
153153
}
154154

155155
// Test subscript methods.
156-
struct SubscriptMethod {
156+
struct SubscriptMethod: Differentiable {
157+
typealias TangentVector = DummyTangentVector
158+
mutating func move(along _: TangentVector) {}
159+
157160
@differentiable // ok
158161
subscript(implicitGetter x: Float) -> Float {
159162
return x
@@ -168,14 +171,14 @@ struct SubscriptMethod {
168171
subscript(explicit x: Float) -> Float {
169172
@differentiable // ok
170173
get { return x }
171-
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
174+
@differentiable // ok
172175
set {}
173176
}
174177

175178
subscript(x: Float, y: Float) -> Float {
176179
@differentiable // ok
177180
get { return x + y }
178-
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
181+
@differentiable // ok
179182
set {}
180183
}
181184
}
@@ -658,16 +661,15 @@ extension InoutParameters {
658661
mutating func mutatingMethod(_ other: Self) -> Self {}
659662
}
660663

661-
// Test unsupported accessors: `set`, `_read`, `_modify`.
664+
// Test accessors: `set`, `_read`, `_modify`.
662665

663-
struct UnsupportedAccessors: Differentiable {
666+
struct Accessors: Differentiable {
664667
typealias TangentVector = DummyTangentVector
665668
mutating func move(along _: TangentVector) {}
666669

667670
var stored: Float
668671
var computed: Float {
669672
// `set` has an `inout` parameter: `(inout Self) -> (Float) -> ()`.
670-
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
671673
@differentiable
672674
set { stored = newValue }
673675

test/AutoDiff/downstream/differentiable_attr_type_checking.swift

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ struct DifferentiableInstanceMethod : Differentiable {
133133
}
134134

135135
// Test subscript methods.
136-
struct SubscriptMethod {
136+
struct SubscriptMethod: Differentiable {
137137
@differentiable // ok
138138
subscript(implicitGetter x: Float) -> Float {
139139
return x
@@ -148,14 +148,14 @@ struct SubscriptMethod {
148148
subscript(explicit x: Float) -> Float {
149149
@differentiable // ok
150150
get { return x }
151-
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
151+
@differentiable // ok
152152
set {}
153153
}
154154

155155
subscript(x: Float, y: Float) -> Float {
156156
@differentiable // ok
157157
get { return x + y }
158-
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
158+
@differentiable // ok
159159
set {}
160160
}
161161
}
@@ -595,13 +595,11 @@ final class FinalClass: Differentiable {
595595
}
596596
}
597597

598-
// Test unsupported accessors: `set`, `_read`, `_modify`.
598+
// Test accessors: `set`, `_read`, `_modify`.
599599

600-
struct UnsupportedAccessors: Differentiable {
600+
struct Accessors: Differentiable {
601601
var stored: Float
602602
var computed: Float {
603-
// `set` has an `inout` parameter: `(inout Self) -> (Float) -> ()`.
604-
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
605603
@differentiable
606604
set { stored = newValue }
607605

0 commit comments

Comments
 (0)