Skip to content

Commit d0f77e7

Browse files
committed
Revamp struct extraction differentiation semantics.
- Remove `@_fieldwiseDifferentiable`. - Remove `StructExtractDifferentiationStrategy`. - Require `TangentVector` to have a member of the same name.
1 parent 19ecff5 commit d0f77e7

12 files changed

+167
-271
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,8 @@ DECL_ATTR(differentiating, Differentiating,
421421
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
422422
OnAccessor | OnFunc | OnConstructor | OnSubscript,
423423
/* Not serialized */ 90)
424-
SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable,
425-
OnNominalType | UserInaccessible, 91)
426424
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
427-
OnVar, 92)
425+
OnVar, 91)
428426

429427
#undef TYPE_ATTR
430428
#undef DECL_ATTR_ALIAS

include/swift/AST/DiagnosticsSIL.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ NOTE(autodiff_opaque_function_not_differentiable,none,
447447
"opaque non-'@differentiable' function is not differentiable", ())
448448
NOTE(autodiff_property_not_differentiable,none,
449449
"property is not differentiable", ())
450+
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
451+
"property cannot be differentiated because '%0.TangentVector' does not "
452+
"have a member named '%1'", (StringRef, StringRef))
450453
NOTE(autodiff_value_defined_here,none,
451454
"value defined here", ())
452455
NOTE(autodiff_when_differentiating_function_call,none,

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2808,11 +2808,6 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
28082808
"'@noDerivative' is only allowed on stored properties in structure types "
28092809
"that declare a conformance to 'Differentiable'", ())
28102810

2811-
// @_fieldwiseDifferentiable attribute
2812-
ERROR(fieldwise_differentiable_only_on_differentiable_structs,none,
2813-
"'@_fieldwiseDifferentiable' is only allowed on structure types that "
2814-
"conform to 'Differentiable'", ())
2815-
28162811
//------------------------------------------------------------------------------
28172812
// MARK: Type Check Expressions
28182813
//------------------------------------------------------------------------------

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 134 additions & 203 deletions
Large diffs are not rendered by default.

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,6 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
690690
auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(),
691691
/*Inherited*/ C.AllocateCopy(inherited),
692692
/*GenericParams*/ {}, parentDC);
693-
structDecl->getAttrs().add(
694-
new (C) FieldwiseDifferentiableAttr(/*implicit*/ true));
695693
structDecl->setImplicit();
696694
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
697695

@@ -960,12 +958,6 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
960958
if (!getAssociatedType(member, parentDC, id))
961959
return nullptr;
962960

963-
// Since associated types will be derived, we make this struct a fieldwise
964-
// differentiable type.
965-
if (!nominal->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>())
966-
nominal->getAttrs().add(
967-
new (C) FieldwiseDifferentiableAttr(/*implicit*/ true));
968-
969961
// Prevent re-synthesis during repeated calls.
970962
// FIXME: Investigate why this is necessary to prevent duplicate synthesis.
971963
auto lookup = nominal->lookupDirect(id);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ class AttributeEarlyChecker : public AttributeVisitor<AttributeEarlyChecker> {
135135
IGNORED_ATTR(Differentiable)
136136
IGNORED_ATTR(Differentiating)
137137
IGNORED_ATTR(CompilerEvaluable)
138-
IGNORED_ATTR(FieldwiseDifferentiable)
139138
IGNORED_ATTR(NoDerivative)
140139
#undef IGNORED_ATTR
141140

@@ -872,7 +871,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
872871
void visitDifferentiableAttr(DifferentiableAttr *attr);
873872
void visitDifferentiatingAttr(DifferentiatingAttr *attr);
874873
void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr);
875-
void visitFieldwiseDifferentiableAttr(FieldwiseDifferentiableAttr *attr);
876874
void visitNoDerivativeAttr(NoDerivativeAttr *attr);
877875
};
878876
} // end anonymous namespace
@@ -3576,23 +3574,6 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {
35763574
// TypeChecker::checkFunctionBodyCompilerEvaluable().
35773575
}
35783576

3579-
// SWIFT_ENABLE_TENSORFLOW
3580-
void AttributeChecker::visitFieldwiseDifferentiableAttr(
3581-
FieldwiseDifferentiableAttr *attr) {
3582-
auto *structDecl = dyn_cast<StructDecl>(D);
3583-
if (!structDecl) {
3584-
diagnoseAndRemoveAttr(attr,
3585-
diag::fieldwise_differentiable_only_on_differentiable_structs);
3586-
return;
3587-
}
3588-
if (!conformsToDifferentiableInModule(
3589-
structDecl->getDeclaredInterfaceType(), D->getModuleContext())) {
3590-
diagnoseAndRemoveAttr(attr,
3591-
diag::fieldwise_differentiable_only_on_differentiable_structs);
3592-
return;
3593-
}
3594-
}
3595-
35963577
// SWIFT_ENABLE_TENSORFLOW
35973578
void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
35983579
auto *vd = dyn_cast<VarDecl>(D);

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,6 @@ namespace {
13031303
UNINTERESTING_ATTR(Differentiable)
13041304
UNINTERESTING_ATTR(Differentiating)
13051305
UNINTERESTING_ATTR(CompilerEvaluable)
1306-
UNINTERESTING_ATTR(FieldwiseDifferentiable)
13071306
UNINTERESTING_ATTR(NoDerivative)
13081307

13091308
// These can't appear on overridable declarations.

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,23 @@ struct S {
2929
}
3030

3131
extension S : Differentiable, VectorNumeric {
32+
struct TangentVector: Differentiable, VectorNumeric {
33+
var dp: Float
34+
}
35+
typealias AllDifferentiableVariables = S
3236
static var zero: S { return S(p: 0) }
3337
typealias Scalar = Float
3438
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
3539
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
3640
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
3741

38-
typealias TangentVector = S
42+
func moved(along direction: TangentVector) -> S {
43+
return S(p: p + direction.dp)
44+
}
3945
}
4046

47+
// expected-error @+2 {{function is not differentiable}}
48+
// expected-note @+1 {{property cannot be differentiated because 'S.TangentVector' does not have a member named 'p'}}
4149
_ = gradient(at: S(p: 0)) { s in 2 * s.p }
4250

4351
struct NoDerivativeProperty : Differentiable {

test/AutoDiff/derived_differentiable_properties.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ public struct Foo : Differentiable {
66
public var a: Float
77
}
88

9-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable {
9+
// CHECK-AST-LABEL: public struct Foo : Differentiable {
1010
// CHECK-AST: @differentiable
1111
// CHECK-AST: public var a: Float
1212
// CHECK-AST: internal init(a: Float)
13-
// CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables
13+
// CHECK-AST: public struct AllDifferentiableVariables
1414
// CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables
1515
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
1616
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
@@ -25,7 +25,7 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
2525
x.a + x.a
2626
}
2727

28-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
28+
// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
2929
// CHECK-AST: internal var a: Float
3030
// CHECK-AST: internal init(a: Float)
3131
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
@@ -36,11 +36,11 @@ struct TestNoDerivative : Differentiable {
3636
@noDerivative var technicallyDifferentiable: Float
3737
}
3838

39-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestNoDerivative : Differentiable {
39+
// CHECK-AST-LABEL: internal struct TestNoDerivative : Differentiable {
4040
// CHECK-AST: var w: Float
4141
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
4242
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
43-
// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric
43+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric
4444
// CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables
4545
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
4646
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
@@ -50,11 +50,11 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
5050
@noDerivative var technicallyDifferentiable: Float
5151
}
5252

53-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
53+
// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
5454
// CHECK-AST: var w: Float
5555
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
5656
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
57-
// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric
57+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric
5858
// CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables
5959
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
6060
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
@@ -66,7 +66,7 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
6666
// TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized.
6767
// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized.
6868

69-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
69+
// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
7070
// CHECK-AST: internal var x: T.TangentVector
7171
// CHECK-AST: internal init(x: T.TangentVector)
7272
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
@@ -81,7 +81,7 @@ public struct ConditionallyDifferentiable<T> {
8181
}
8282
extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}
8383

84-
// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct ConditionallyDifferentiable<T> {
84+
// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
8585
// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
8686
// CHECK-AST: public let x: T
8787
// CHECK-AST: internal init(x: T)

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import StdlibUnittest
99
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
1010

1111
struct TangentSpace : VectorNumeric {
12-
let dx, dy: Float
12+
let x, y: Float
1313
}
1414

1515
extension TangentSpace : Differentiable {
@@ -26,18 +26,14 @@ struct Space {
2626
}
2727

2828
func vjpX() -> (Float, (Float) -> TangentSpace) {
29-
return (x, { v in TangentSpace(dx: v, dy: 0) } )
29+
return (x, { v in TangentSpace(x: v, y: 0) } )
3030
}
3131

3232
private let storedX: Float
33-
33+
3434
@differentiable
3535
var y: Float
3636

37-
func vjpY() -> (Float, (Float) -> TangentSpace) {
38-
return (y, { v in TangentSpace(dx: 0, dy: v) })
39-
}
40-
4137
init(x: Float, y: Float) {
4238
self.storedX = x
4339
self.y = y
@@ -47,23 +43,23 @@ struct Space {
4743
extension Space : Differentiable {
4844
typealias TangentVector = TangentSpace
4945
func moved(along: TangentSpace) -> Space {
50-
return Space(x: x + along.dx, y: y + along.dy)
46+
return Space(x: x + along.x, y: y + along.y)
5147
}
5248
}
5349

5450
E2EDifferentiablePropertyTests.test("computed property") {
5551
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
5652
return 2 * point.x
5753
}
58-
let expectedGrad = TangentSpace(dx: 2, dy: 0)
54+
let expectedGrad = TangentSpace(x: 2, y: 0)
5955
expectEqual(expectedGrad, actualGrad)
6056
}
6157

6258
E2EDifferentiablePropertyTests.test("stored property") {
6359
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
6460
return 3 * point.y
6561
}
66-
let expectedGrad = TangentSpace(dx: 0, dy: 3)
62+
let expectedGrad = TangentSpace(x: 0, y: 3)
6763
expectEqual(expectedGrad, actualGrad)
6864
}
6965

@@ -85,7 +81,6 @@ E2EDifferentiablePropertyTests.test("generic stored property") {
8581
expectEqual(expectedGrad, actualGrad)
8682
}
8783

88-
@_fieldwiseDifferentiable
8984
struct ProductSpaceSelfTangent : VectorNumeric {
9085
let x, y: Float
9186
}
@@ -110,7 +105,6 @@ extension ProductSpaceOtherTangentTangentSpace : Differentiable {
110105
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
111106
}
112107

113-
@_fieldwiseDifferentiable
114108
struct ProductSpaceOtherTangent {
115109
let x, y: Float
116110
}

test/AutoDiff/separate_cotangent_type.swift

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@ import Glibc
1010

1111
var SeparateTangentTypeTests = TestSuite("SeparateTangentType")
1212

13-
@_fieldwiseDifferentiable
1413
struct DifferentiableSubset : Differentiable {
1514
@differentiable(wrt: self)
1615
var w: Float
1716
@differentiable(wrt: self)
1817
var b: Float
1918
@noDerivative var flag: Bool
2019

21-
@_fieldwiseDifferentiable
2220
struct TangentVector : Differentiable, VectorNumeric {
2321
typealias TangentVector = DifferentiableSubset.TangentVector
2422
var w: Float
@@ -41,12 +39,10 @@ SeparateTangentTypeTests.test("Initialization") {
4139
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
4240
}
4341

44-
// FIXME(SR-9602): If `TangentVector` is not marked
45-
// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
46-
// SeparateTangentTypeTests.test("SomeArithmetics") {
47-
// let x = DifferentiableSubset(w: 0, b: 1, flag: false)
48-
// let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
49-
// expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
50-
// }
42+
SeparateTangentTypeTests.test("SomeArithmetics") {
43+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
44+
let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
45+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
46+
}
5147

5248
runAllTests()

test/AutoDiff/witness_table_silgen.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ protocol Proto : Differentiable {
1111
func function3(_ x: Float, _ y: Double) -> Double
1212
}
1313

14-
@_fieldwiseDifferentiable
1514
struct S : Proto, VectorNumeric {
1615
static var zero: S { return S(p: 0) }
1716
typealias Scalar = Float

0 commit comments

Comments
 (0)