Skip to content

Commit 982063d

Browse files
authored
[AutoDiff] Revamp 'struct_extract' differentiation strategy. (#25151)
Solution: - Ban user-defined derivatives for stored properties. This both prevents potential user-introduced errors and reduces the compiler implementation complexity. - When differentiating a stored property access, require that the TangentVector contains a stored property of the same name. Implementation: - Reject vjp: or jvp: in a @differentiable attribute on a stored property. - Make AdjointEmitter::visitStructExtractInst and AdjointEmitter::visitStructElementAddrInst differentiate fieldwise, forming a TangentVector using a struct instruction. - Eliminate @_fieldwiseDifferentiable everywhere. - Remove StructExtractDifferentiationStrategy and all related logic. - When a TangentVector does not contain a property with the requested name that corresponds to a property in the original type, emit a diagnostic property cannot be differentiated because 'Foo.TangentVector' does not have a member named 'x'. - Turn ADContext::emitNondifferentiabilityError methods into templates that take diagnostics with arbitrary arguments. This makes it possible for us to emit finer-grained diagnostics in the future.
1 parent ba4f0b4 commit 982063d

19 files changed

+248
-534
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,8 @@ DECL_ATTR(differentiating, Differentiating,
421421
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
422422
OnAccessor | OnFunc | OnConstructor | OnSubscript,
423423
/* Not serialized */ 90)
424-
SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable,
425-
OnNominalType | UserInaccessible, 91)
426424
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
427-
OnVar, 92)
425+
OnVar, 91)
428426

429427
#undef TYPE_ATTR
430428
#undef DECL_ATTR_ALIAS

include/swift/AST/DiagnosticsSIL.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ NOTE(autodiff_opaque_function_not_differentiable,none,
447447
"opaque non-'@differentiable' function is not differentiable", ())
448448
NOTE(autodiff_property_not_differentiable,none,
449449
"property is not differentiable", ())
450+
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
451+
"property cannot be differentiated because '%0.TangentVector' does not "
452+
"have a member named '%1'", (StringRef, StringRef))
450453
NOTE(autodiff_value_defined_here,none,
451454
"value defined here", ())
452455
NOTE(autodiff_when_differentiating_function_call,none,

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
27332733
"layout requirement are not supported by '@differentiable' attribute", ())
27342734
ERROR(differentiable_attr_class_unsupported,none,
27352735
"class members cannot be marked with '@differentiable'", ())
2736+
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
2737+
"'jvp:' or 'vjp:' cannot be specified for stored properties", ())
27362738
NOTE(protocol_witness_missing_specific_differentiable_attr,none,
27372739
"candidate is missing attribute '%0'", (StringRef))
27382740

@@ -2806,11 +2808,6 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
28062808
"'@noDerivative' is only allowed on stored properties in structure types "
28072809
"that declare a conformance to 'Differentiable'", ())
28082810

2809-
// @_fieldwiseDifferentiable attribute
2810-
ERROR(fieldwise_differentiable_only_on_differentiable_structs,none,
2811-
"'@_fieldwiseDifferentiable' is only allowed on structure types that "
2812-
"conform to 'Differentiable'", ())
2813-
28142811
//------------------------------------------------------------------------------
28152812
// MARK: Type Check Expressions
28162813
//------------------------------------------------------------------------------

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 132 additions & 358 deletions
Large diffs are not rendered by default.

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,6 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
690690
auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(),
691691
/*Inherited*/ C.AllocateCopy(inherited),
692692
/*GenericParams*/ {}, parentDC);
693-
structDecl->getAttrs().add(
694-
new (C) FieldwiseDifferentiableAttr(/*implicit*/ true));
695693
structDecl->setImplicit();
696694
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
697695

@@ -960,12 +958,6 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
960958
if (!getAssociatedType(member, parentDC, id))
961959
return nullptr;
962960

963-
// Since associated types will be derived, we make this struct a fieldwise
964-
// differentiable type.
965-
if (!nominal->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>())
966-
nominal->getAttrs().add(
967-
new (C) FieldwiseDifferentiableAttr(/*implicit*/ true));
968-
969961
// Prevent re-synthesis during repeated calls.
970962
// FIXME: Investigate why this is necessary to prevent duplicate synthesis.
971963
auto lookup = nominal->lookupDirect(id);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ class AttributeEarlyChecker : public AttributeVisitor<AttributeEarlyChecker> {
135135
IGNORED_ATTR(Differentiable)
136136
IGNORED_ATTR(Differentiating)
137137
IGNORED_ATTR(CompilerEvaluable)
138-
IGNORED_ATTR(FieldwiseDifferentiable)
139138
IGNORED_ATTR(NoDerivative)
140139
#undef IGNORED_ATTR
141140

@@ -872,7 +871,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
872871
void visitDifferentiableAttr(DifferentiableAttr *attr);
873872
void visitDifferentiatingAttr(DifferentiatingAttr *attr);
874873
void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr);
875-
void visitFieldwiseDifferentiableAttr(FieldwiseDifferentiableAttr *attr);
876874
void visitNoDerivativeAttr(NoDerivativeAttr *attr);
877875
};
878876
} // end anonymous namespace
@@ -2887,6 +2885,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
28872885

28882886
AbstractFunctionDecl *original = dyn_cast<AbstractFunctionDecl>(D);
28892887
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
2888+
if (asd->getImplInfo().isSimpleStored() &&
2889+
(attr->getJVP() || attr->getVJP())) {
2890+
diagnoseAndRemoveAttr(attr,
2891+
diag::differentiable_attr_stored_property_variable_unsupported);
2892+
return;
2893+
}
28902894
// When used directly on a storage decl (stored/computed property or
28912895
// subscript), the getter is currently inferred to be `@differentiable`.
28922896
// TODO(TF-129): Infer setter to also be `@differentiable` after
@@ -3570,23 +3574,6 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {
35703574
// TypeChecker::checkFunctionBodyCompilerEvaluable().
35713575
}
35723576

3573-
// SWIFT_ENABLE_TENSORFLOW
3574-
void AttributeChecker::visitFieldwiseDifferentiableAttr(
3575-
FieldwiseDifferentiableAttr *attr) {
3576-
auto *structDecl = dyn_cast<StructDecl>(D);
3577-
if (!structDecl) {
3578-
diagnoseAndRemoveAttr(attr,
3579-
diag::fieldwise_differentiable_only_on_differentiable_structs);
3580-
return;
3581-
}
3582-
if (!conformsToDifferentiableInModule(
3583-
structDecl->getDeclaredInterfaceType(), D->getModuleContext())) {
3584-
diagnoseAndRemoveAttr(attr,
3585-
diag::fieldwise_differentiable_only_on_differentiable_structs);
3586-
return;
3587-
}
3588-
}
3589-
35903577
// SWIFT_ENABLE_TENSORFLOW
35913578
void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
35923579
auto *vd = dyn_cast<VarDecl>(D);

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,6 @@ namespace {
13031303
UNINTERESTING_ATTR(Differentiable)
13041304
UNINTERESTING_ATTR(Differentiating)
13051305
UNINTERESTING_ATTR(CompilerEvaluable)
1306-
UNINTERESTING_ATTR(FieldwiseDifferentiable)
13071306
UNINTERESTING_ATTR(NoDerivative)
13081307

13091308
// These can't appear on overridable declarations.

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,23 @@ struct S {
2929
}
3030

3131
extension S : Differentiable, VectorNumeric {
32+
struct TangentVector: Differentiable, VectorNumeric {
33+
var dp: Float
34+
}
35+
typealias AllDifferentiableVariables = S
3236
static var zero: S { return S(p: 0) }
3337
typealias Scalar = Float
3438
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
3539
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
3640
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
3741

38-
typealias TangentVector = S
42+
func moved(along direction: TangentVector) -> S {
43+
return S(p: p + direction.dp)
44+
}
3945
}
4046

4147
// expected-error @+2 {{function is not differentiable}}
42-
// expected-note @+1 {{property is not differentiable}}
48+
// expected-note @+1 {{property cannot be differentiated because 'S.TangentVector' does not have a member named 'p'}}
4349
_ = gradient(at: S(p: 0)) { s in 2 * s.p }
4450

4551
struct NoDerivativeProperty : Differentiable {

test/AutoDiff/derived_differentiable_properties.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ public struct Foo : Differentiable {
66
public var a: Float
77
}
88

9-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable {
9+
// CHECK-AST-LABEL: public struct Foo : Differentiable {
1010
// CHECK-AST: @differentiable
1111
// CHECK-AST: public var a: Float
1212
// CHECK-AST: internal init(a: Float)
13-
// CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables
13+
// CHECK-AST: public struct AllDifferentiableVariables
1414
// CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables
1515
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
1616
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
@@ -25,7 +25,7 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
2525
x.a + x.a
2626
}
2727

28-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
28+
// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
2929
// CHECK-AST: internal var a: Float
3030
// CHECK-AST: internal init(a: Float)
3131
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
@@ -36,11 +36,11 @@ struct TestNoDerivative : Differentiable {
3636
@noDerivative var technicallyDifferentiable: Float
3737
}
3838

39-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestNoDerivative : Differentiable {
39+
// CHECK-AST-LABEL: internal struct TestNoDerivative : Differentiable {
4040
// CHECK-AST: var w: Float
4141
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
4242
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
43-
// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric
43+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric
4444
// CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables
4545
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
4646
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
@@ -50,11 +50,11 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
5050
@noDerivative var technicallyDifferentiable: Float
5151
}
5252

53-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
53+
// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
5454
// CHECK-AST: var w: Float
5555
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
5656
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
57-
// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric
57+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric
5858
// CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables
5959
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
6060
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
@@ -66,7 +66,7 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
6666
// TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized.
6767
// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized.
6868

69-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
69+
// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
7070
// CHECK-AST: internal var x: T.TangentVector
7171
// CHECK-AST: internal init(x: T.TangentVector)
7272
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
@@ -81,7 +81,7 @@ public struct ConditionallyDifferentiable<T> {
8181
}
8282
extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}
8383

84-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct ConditionallyDifferentiable<T> {
84+
// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
8585
// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
8686
// CHECK-AST: public let x: T
8787
// CHECK-AST: internal init(x: T)

test/AutoDiff/differentiable_attr_silgen.swift

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -76,43 +76,6 @@ public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)
7676

7777
// CHECK-LABEL: sil [ossa] @dhasvjp
7878

79-
//===----------------------------------------------------------------------===//
80-
// Stored property
81-
//===----------------------------------------------------------------------===//
82-
83-
struct DiffStoredProp {
84-
@differentiable(wrt: (self), jvp: storedPropJVP, vjp: storedPropVJP)
85-
let storedProp: Float
86-
87-
@_silgen_name("storedPropJVP")
88-
func storedPropJVP() -> (Float, (DiffStoredProp) -> Float) {
89-
fatalError("unimplemented")
90-
}
91-
92-
@_silgen_name("storedPropVJP")
93-
func storedPropVJP() -> (Float, (Float) -> DiffStoredProp) {
94-
fatalError("unimplemented")
95-
}
96-
}
97-
98-
extension DiffStoredProp : VectorNumeric {
99-
static var zero: DiffStoredProp { fatalError("unimplemented") }
100-
static func + (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
101-
fatalError("unimplemented")
102-
}
103-
static func - (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
104-
fatalError("unimplemented")
105-
}
106-
typealias Scalar = Float
107-
static func * (lhs: Float, rhs: DiffStoredProp) -> DiffStoredProp {
108-
fatalError("unimplemented")
109-
}
110-
}
111-
112-
extension DiffStoredProp : Differentiable {
113-
typealias TangentVector = DiffStoredProp
114-
}
115-
11679
//===----------------------------------------------------------------------===//
11780
// Computed property
11881
//===----------------------------------------------------------------------===//

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
// RUN: %target-swift-frontend -typecheck -verify %s
22

33
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
4-
let global: Float = 1
4+
let globalConst: Float = 1
5+
6+
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
7+
var globalVar: Float = 1
58

69
func testLocalVariables() {
710
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
@@ -225,25 +228,18 @@ class Foo {
225228
}
226229

227230
struct JVPStruct {
231+
@differentiable
228232
let p: Float
229233

230-
@differentiable(wrt: (self), jvp: storedPropJVP)
231-
let storedImmutableOk: Float
232-
233-
// expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
234-
@differentiable(wrt: (self), jvp: storedPropJVP)
235-
let storedImmutableWrongType: Double
236-
237-
@differentiable(wrt: (self), jvp: storedPropJVP)
238-
var storedMutableOk: Float
239-
240-
// expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
241-
@differentiable(wrt: (self), jvp: storedPropJVP)
242-
var storedMutableWrongType: Double
234+
// expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
235+
@differentiable(wrt: (self), jvp: funcJVP)
236+
func funcWrongType() -> Double {
237+
fatalError("unimplemented")
238+
}
243239
}
244240

245241
extension JVPStruct {
246-
func storedPropJVP() -> (Float, (JVPStruct) -> Float) {
242+
func funcJVP() -> (Float, (JVPStruct) -> Float) {
247243
fatalError("unimplemented")
248244
}
249245
}
@@ -383,23 +379,15 @@ func vjpNonDiffResult2(x: Float) -> (Float, Int) {
383379
struct VJPStruct {
384380
let p: Float
385381

386-
@differentiable(vjp: storedPropVJP)
387-
let storedImmutableOk: Float
388-
389-
// expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
390-
@differentiable(vjp: storedPropVJP)
391-
let storedImmutableWrongType: Double
392-
393-
@differentiable(vjp: storedPropVJP)
394-
var storedMutableOk: Float
395-
396-
// expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
397-
@differentiable(vjp: storedPropVJP)
398-
var storedMutableWrongType: Double
382+
// expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
383+
@differentiable(vjp: funcVJP)
384+
func funcWrongType() -> Double {
385+
fatalError("unimplemented")
386+
}
399387
}
400388

401389
extension VJPStruct {
402-
func storedPropVJP() -> (Float, (Float) -> VJPStruct) {
390+
func funcVJP() -> (Float, (Float) -> VJPStruct) {
403391
fatalError("unimplemented")
404392
}
405393
}

test/AutoDiff/differentiating_attr_type_checking.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,23 @@ func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float)
295295
func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
296296
return (x, { $0 })
297297
}
298+
299+
// Test usage of `@differentiable` on a stored property
300+
struct PropertyDiff : Differentiable & AdditiveArithmetic {
301+
// expected-error @+1 {{'jvp:' or 'vjp:' cannot be specified for stored properties}}
302+
@differentiable(vjp: vjpPropertyA)
303+
var a: Float = 1
304+
typealias TangentVector = PropertyDiff
305+
typealias AllDifferentiableVariables = PropertyDiff
306+
func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) {
307+
(.zero, { _ in .zero })
308+
}
309+
}
310+
311+
@differentiable
312+
func f(_ x: PropertyDiff) -> Float {
313+
return x.a
314+
}
315+
316+
let a = gradient(at: PropertyDiff(), in: f)
317+
print(a)

0 commit comments

Comments
 (0)