Skip to content

[AutoDiff] Revamp 'struct_extract' differentiation strategy. #25151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,8 @@ DECL_ATTR(differentiating, Differentiating,
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
OnAccessor | OnFunc | OnConstructor | OnSubscript,
/* Not serialized */ 90)
SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable,
OnNominalType | UserInaccessible, 91)
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
OnVar, 92)
OnVar, 91)

#undef TYPE_ATTR
#undef DECL_ATTR_ALIAS
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ NOTE(autodiff_opaque_function_not_differentiable,none,
"opaque non-'@differentiable' function is not differentiable", ())
NOTE(autodiff_property_not_differentiable,none,
"property is not differentiable", ())
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
"property cannot be differentiated because '%0.TangentVector' does not "
"have a member named '%1'", (StringRef, StringRef))
NOTE(autodiff_value_defined_here,none,
"value defined here", ())
NOTE(autodiff_when_differentiating_function_call,none,
Expand Down
7 changes: 2 additions & 5 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
"layout requirement are not supported by '@differentiable' attribute", ())
ERROR(differentiable_attr_class_unsupported,none,
"class members cannot be marked with '@differentiable'", ())
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
"'jvp:' or 'vjp:' cannot be specified for stored properties", ())
NOTE(protocol_witness_missing_specific_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

Expand Down Expand Up @@ -2806,11 +2808,6 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
"'@noDerivative' is only allowed on stored properties in structure types "
"that declare a conformance to 'Differentiable'", ())

// @_fieldwiseDifferentiable attribute
ERROR(fieldwise_differentiable_only_on_differentiable_structs,none,
"'@_fieldwiseDifferentiable' is only allowed on structure types that "
"conform to 'Differentiable'", ())

//------------------------------------------------------------------------------
// MARK: Type Check Expressions
//------------------------------------------------------------------------------
Expand Down
490 changes: 132 additions & 358 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp

Large diffs are not rendered by default.

8 changes: 0 additions & 8 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,8 +690,6 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(),
/*Inherited*/ C.AllocateCopy(inherited),
/*GenericParams*/ {}, parentDC);
structDecl->getAttrs().add(
new (C) FieldwiseDifferentiableAttr(/*implicit*/ true));
structDecl->setImplicit();
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);

Expand Down Expand Up @@ -960,12 +958,6 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
if (!getAssociatedType(member, parentDC, id))
return nullptr;

// Since associated types will be derived, we make this struct a fieldwise
// differentiable type.
if (!nominal->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>())
nominal->getAttrs().add(
new (C) FieldwiseDifferentiableAttr(/*implicit*/ true));

// Prevent re-synthesis during repeated calls.
// FIXME: Investigate why this is necessary to prevent duplicate synthesis.
auto lookup = nominal->lookupDirect(id);
Expand Down
25 changes: 6 additions & 19 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ class AttributeEarlyChecker : public AttributeVisitor<AttributeEarlyChecker> {
IGNORED_ATTR(Differentiable)
IGNORED_ATTR(Differentiating)
IGNORED_ATTR(CompilerEvaluable)
IGNORED_ATTR(FieldwiseDifferentiable)
IGNORED_ATTR(NoDerivative)
#undef IGNORED_ATTR

Expand Down Expand Up @@ -872,7 +871,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
void visitDifferentiableAttr(DifferentiableAttr *attr);
void visitDifferentiatingAttr(DifferentiatingAttr *attr);
void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr);
void visitFieldwiseDifferentiableAttr(FieldwiseDifferentiableAttr *attr);
void visitNoDerivativeAttr(NoDerivativeAttr *attr);
};
} // end anonymous namespace
Expand Down Expand Up @@ -2887,6 +2885,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {

AbstractFunctionDecl *original = dyn_cast<AbstractFunctionDecl>(D);
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
if (asd->getImplInfo().isSimpleStored() &&
(attr->getJVP() || attr->getVJP())) {
diagnoseAndRemoveAttr(attr,
diag::differentiable_attr_stored_property_variable_unsupported);
return;
}
// When used directly on a storage decl (stored/computed property or
// subscript), the getter is currently inferred to be `@differentiable`.
// TODO(TF-129): Infer setter to also be `@differentiable` after
Expand Down Expand Up @@ -3570,23 +3574,6 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {
// TypeChecker::checkFunctionBodyCompilerEvaluable().
}

// SWIFT_ENABLE_TENSORFLOW
void AttributeChecker::visitFieldwiseDifferentiableAttr(
FieldwiseDifferentiableAttr *attr) {
auto *structDecl = dyn_cast<StructDecl>(D);
if (!structDecl) {
diagnoseAndRemoveAttr(attr,
diag::fieldwise_differentiable_only_on_differentiable_structs);
return;
}
if (!conformsToDifferentiableInModule(
structDecl->getDeclaredInterfaceType(), D->getModuleContext())) {
diagnoseAndRemoveAttr(attr,
diag::fieldwise_differentiable_only_on_differentiable_structs);
return;
}
}

// SWIFT_ENABLE_TENSORFLOW
void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
auto *vd = dyn_cast<VarDecl>(D);
Expand Down
1 change: 0 additions & 1 deletion lib/Sema/TypeCheckDeclOverride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,6 @@ namespace {
UNINTERESTING_ATTR(Differentiable)
UNINTERESTING_ATTR(Differentiating)
UNINTERESTING_ATTR(CompilerEvaluable)
UNINTERESTING_ATTR(FieldwiseDifferentiable)
UNINTERESTING_ATTR(NoDerivative)

// These can't appear on overridable declarations.
Expand Down
10 changes: 8 additions & 2 deletions test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,23 @@ struct S {
}

extension S : Differentiable, VectorNumeric {
struct TangentVector: Differentiable, VectorNumeric {
var dp: Float
}
typealias AllDifferentiableVariables = S
static var zero: S { return S(p: 0) }
typealias Scalar = Float
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }

typealias TangentVector = S
func moved(along direction: TangentVector) -> S {
return S(p: p + direction.dp)
}
}

// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{property is not differentiable}}
// expected-note @+1 {{property cannot be differentiated because 'S.TangentVector' does not have a member named 'p'}}
_ = gradient(at: S(p: 0)) { s in 2 * s.p }

struct NoDerivativeProperty : Differentiable {
Expand Down
18 changes: 9 additions & 9 deletions test/AutoDiff/derived_differentiable_properties.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ public struct Foo : Differentiable {
public var a: Float
}

// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable {
// CHECK-AST-LABEL: public struct Foo : Differentiable {
// CHECK-AST: @differentiable
// CHECK-AST: public var a: Float
// CHECK-AST: internal init(a: Float)
// CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables
// CHECK-AST: public struct AllDifferentiableVariables
// CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
Expand All @@ -25,7 +25,7 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
x.a + x.a
}

// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
// CHECK-AST: internal var a: Float
// CHECK-AST: internal init(a: Float)
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
Expand All @@ -36,11 +36,11 @@ struct TestNoDerivative : Differentiable {
@noDerivative var technicallyDifferentiable: Float
}

// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestNoDerivative : Differentiable {
// CHECK-AST-LABEL: internal struct TestNoDerivative : Differentiable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric
// CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
Expand All @@ -50,11 +50,11 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
@noDerivative var technicallyDifferentiable: Float
}

// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric
// CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
Expand All @@ -66,7 +66,7 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
// TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized.
// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized.

// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
// CHECK-AST: internal var x: T.TangentVector
// CHECK-AST: internal init(x: T.TangentVector)
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
Expand All @@ -81,7 +81,7 @@ public struct ConditionallyDifferentiable<T> {
}
extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}

// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct ConditionallyDifferentiable<T> {
// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
// CHECK-AST: public let x: T
// CHECK-AST: internal init(x: T)
Expand Down
37 changes: 0 additions & 37 deletions test/AutoDiff/differentiable_attr_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,43 +76,6 @@ public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)

// CHECK-LABEL: sil [ossa] @dhasvjp

//===----------------------------------------------------------------------===//
// Stored property
//===----------------------------------------------------------------------===//

struct DiffStoredProp {
@differentiable(wrt: (self), jvp: storedPropJVP, vjp: storedPropVJP)
let storedProp: Float

@_silgen_name("storedPropJVP")
func storedPropJVP() -> (Float, (DiffStoredProp) -> Float) {
fatalError("unimplemented")
}

@_silgen_name("storedPropVJP")
func storedPropVJP() -> (Float, (Float) -> DiffStoredProp) {
fatalError("unimplemented")
}
}

extension DiffStoredProp : VectorNumeric {
static var zero: DiffStoredProp { fatalError("unimplemented") }
static func + (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
fatalError("unimplemented")
}
static func - (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
fatalError("unimplemented")
}
typealias Scalar = Float
static func * (lhs: Float, rhs: DiffStoredProp) -> DiffStoredProp {
fatalError("unimplemented")
}
}

extension DiffStoredProp : Differentiable {
typealias TangentVector = DiffStoredProp
}

//===----------------------------------------------------------------------===//
// Computed property
//===----------------------------------------------------------------------===//
Expand Down
46 changes: 17 additions & 29 deletions test/AutoDiff/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// RUN: %target-swift-frontend -typecheck -verify %s

@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
let global: Float = 1
let globalConst: Float = 1

@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
var globalVar: Float = 1

func testLocalVariables() {
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
Expand Down Expand Up @@ -225,25 +228,18 @@ class Foo {
}

struct JVPStruct {
@differentiable
let p: Float

@differentiable(wrt: (self), jvp: storedPropJVP)
let storedImmutableOk: Float

// expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
@differentiable(wrt: (self), jvp: storedPropJVP)
let storedImmutableWrongType: Double

@differentiable(wrt: (self), jvp: storedPropJVP)
var storedMutableOk: Float

// expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
@differentiable(wrt: (self), jvp: storedPropJVP)
var storedMutableWrongType: Double
// expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
@differentiable(wrt: (self), jvp: funcJVP)
func funcWrongType() -> Double {
fatalError("unimplemented")
}
}

extension JVPStruct {
func storedPropJVP() -> (Float, (JVPStruct) -> Float) {
func funcJVP() -> (Float, (JVPStruct) -> Float) {
fatalError("unimplemented")
}
}
Expand Down Expand Up @@ -383,23 +379,15 @@ func vjpNonDiffResult2(x: Float) -> (Float, Int) {
struct VJPStruct {
let p: Float

@differentiable(vjp: storedPropVJP)
let storedImmutableOk: Float
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For type checking tests like this, I'd actually suggest turning them into computed properties. These tests are pretty important.


// expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
@differentiable(vjp: storedPropVJP)
let storedImmutableWrongType: Double

@differentiable(vjp: storedPropVJP)
var storedMutableOk: Float

// expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
@differentiable(vjp: storedPropVJP)
var storedMutableWrongType: Double
// expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
@differentiable(vjp: funcVJP)
func funcWrongType() -> Double {
fatalError("unimplemented")
}
}

extension VJPStruct {
func storedPropVJP() -> (Float, (Float) -> VJPStruct) {
func funcVJP() -> (Float, (Float) -> VJPStruct) {
fatalError("unimplemented")
}
}
Expand Down
20 changes: 20 additions & 0 deletions test/AutoDiff/differentiating_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,23 @@ func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float)
func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}

// Test usage of `@differentiable` on a stored property
struct PropertyDiff : Differentiable & AdditiveArithmetic {
// expected-error @+1 {{'jvp:' or 'vjp:' cannot be specified for stored properties}}
@differentiable(vjp: vjpPropertyA)
var a: Float = 1
typealias TangentVector = PropertyDiff
typealias AllDifferentiableVariables = PropertyDiff
func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) {
(.zero, { _ in .zero })
}
}

@differentiable
func f(_ x: PropertyDiff) -> Float {
return x.a
}

let a = gradient(at: PropertyDiff(), in: f)
print(a)
Loading