Skip to content

[AutoDiff] [Sema] Include certain 'let' properties in 'Differentiable' derived conformances. #33700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2806,15 +2806,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
"requires 'wrappedValue' in property wrapper %0 to be mutable; "
"add an explicit '@noDerivative' attribute"
"requires 'wrappedValue' in property wrapper %0 to be mutable or have a "
"non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute"
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
(/*wrapperType*/ Identifier, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
"requires all stored properties not marked with `@noDerivative` to be "
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
"mutable or have a non-mutating 'move(along:)'; use 'var' instead, or "
"add an explicit '@noDerivative' attribute "
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))

Expand Down
56 changes: 42 additions & 14 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,37 @@

using namespace swift;

/// Return true if `move(along:)` can be invoked on the given `Differentiable`-
/// conforming property.
///
/// If the given property is a `var`, return true because `move(along:)` can be
/// invoked regardless. Otherwise, return true if and only if the property's
/// type's 'Differentiable.move(along:)' witness is non-mutating.
static bool canInvokeMoveAlongOnProperty(
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
assert(diffableConformance && "Property must conform to 'Differentiable'");
// `var` always supports `move(along:)` since it is mutable.
if (vd->getIntroducer() == VarDecl::Introducer::Var)
return true;
// When the property is a `let`, the only case that would be supported is when
// it has a `move(along:)` protocol requirement witness that is non-mutating.
auto interfaceType = vd->getInterfaceType();
auto &C = vd->getASTContext();
auto witness = diffableConformance.getWitnessByName(
interfaceType, DeclName(C, C.Id_move, {C.Id_along}));
if (!witness)
return false;
auto *decl = cast<FuncDecl>(witness.getDecl());
return decl->isNonMutating();
}

/// Get the stored properties of a nominal type that are relevant for
/// differentiation, except the ones tagged `@noDerivative`.
static void
getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
SmallVectorImpl<VarDecl *> &result,
bool includeLetProperties = false) {
getStoredPropertiesForDifferentiation(
NominalTypeDecl *nominal, DeclContext *DC,
SmallVectorImpl<VarDecl *> &result,
bool includeLetPropertiesWithNonmutatingMoveAlong = false) {
auto &C = nominal->getASTContext();
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
for (auto *vd : nominal->getStoredProperties()) {
Expand All @@ -53,15 +78,18 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
// Skip stored properties with `@noDerivative` attribute.
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Skip `let` stored properties if requested.
// `mutating func move(along:)` cannot be synthesized to update `let`
// properties.
if (!includeLetProperties && vd->isLet())
continue;
if (vd->getInterfaceType()->hasError())
continue;
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
if (!TypeChecker::conformsToProtocol(varType, diffableProto, nominal))
auto conformance = TypeChecker::conformsToProtocol(
varType, diffableProto, nominal);
if (!conformance)
continue;
// Skip `let` stored properties with a mutating `move(along:)` if requested.
// `mutating func move(along:)` cannot be synthesized to update `let`
// properties.
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
!canInvokeMoveAlongOnProperty(vd, conformance))
continue;
result.push_back(vd);
}
Expand Down Expand Up @@ -782,18 +810,18 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
continue;
// Check whether to diagnose stored property.
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
bool conformsToDifferentiable =
!TypeChecker::conformsToProtocol(varType, diffableProto, nominal)
.isInvalid();
auto diffableConformance =
TypeChecker::conformsToProtocol(varType, diffableProto, nominal);
// If stored property should not be diagnosed, continue.
if (conformsToDifferentiable && !vd->isLet())
if (diffableConformance &&
canInvokeMoveAlongOnProperty(vd, diffableConformance))
continue;
// Otherwise, add an implicit `@noDerivative` attribute.
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
assert(loc.isValid() && "Expected valid source location");
// Diagnose properties that do not conform to `Differentiable`.
if (!conformsToDifferentiable) {
if (!diffableConformance) {
Context.Diags
.diagnose(
loc,
Expand Down
48 changes: 42 additions & 6 deletions test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,70 @@ func testEmpty() {
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
}

protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {}
extension DifferentiableWithNonmutatingMoveAlong {
func move(along _: TangentVector) {}
}

class EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong {
typealias TangentVector = Empty.TangentVector
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
static func proof_that_i_have_nonmutating_move_along() {
let empty = EmptyWithInheritedNonmutatingMoveAlong()
empty.move(along: .init())
}
}

class EmptyWrapper<T: Differentiable & AnyObject>: Differentiable {}
func testEmptyWrapper() {
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
assertConformsToAdditiveArithmetic(EmptyWrapper<Empty>.TangentVector.self)
}

// Test structs with `let` stored properties.
// Derived conformances fail because `mutating func move` requires all stored
// properties to be mutable.
class ImmutableStoredProperties: Differentiable {
class ImmutableStoredProperties<T: Differentiable & AnyObject>: Differentiable {
var okay: Float

// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let nondiff: Int

// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let diff: Float

init() {
let letClass: Empty // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'.

let letClassWithInheritedNonmutatingMoveAlong: EmptyWithInheritedNonmutatingMoveAlong

// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let letClassGeneric: T // Error due to lack of non-mutating 'move(along:)'.

let letClassWrappingGeneric: EmptyWrapper<T> // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'.

init(letClassGeneric: T) {
okay = 0
nondiff = 0
diff = 0
letClass = Empty()
self.letClassGeneric = letClassGeneric
self.letClassWrappingGeneric = EmptyWrapper<T>()
}
}
func testImmutableStoredProperties() {
_ = ImmutableStoredProperties.TangentVector(okay: 1)
_ = ImmutableStoredProperties<Empty>.TangentVector(
okay: 1,
letClass: Empty.TangentVector(),
letClassWithInheritedNonmutatingMoveAlong: Empty.TangentVector(),
letClassWrappingGeneric: EmptyWrapper<Empty>.TangentVector())
}
class MutableStoredPropertiesWithInitialValue: Differentiable {
var x = Float(1)
var y = Double(1)
}
// Test class with both an empty constructor and memberwise initializer.
class AllMixedStoredPropertiesHaveInitialValue: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let x = Float(1)
var y = Float(1)
// Memberwise initializer should be `init(y:)` since `x` is immutable.
Expand Down Expand Up @@ -550,7 +586,7 @@ struct Generic<T> {}
extension Generic: Differentiable where T: Differentiable {}

class WrappedProperties: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}}
@ImmutableWrapper var immutableInt: Generic<Int> = Generic()

// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
Expand Down
49 changes: 44 additions & 5 deletions test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@ func testEmpty() {
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
}

struct EmptyWithConcreteNonmutatingMoveAlong: Differentiable {
typealias TangentVector = Empty.TangentVector
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
func move(along _: TangentVector) {}
static func proof_that_i_have_nonmutating_move_along() {
let empty = Self()
empty.move(along: .init())
}
}

protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {}
extension DifferentiableWithNonmutatingMoveAlong {
func move(along _: TangentVector) {}
}

struct EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong {
typealias TangentVector = Empty.TangentVector
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
static func proof_that_i_have_nonmutating_move_along() {
let empty = Self()
empty.move(along: .init())
}
}

class EmptyClass: Differentiable {}
func testEmptyClass() {
assertConformsToAdditiveArithmetic(EmptyClass.TangentVector.self)
}

// Test interaction with `AdditiveArithmetic` derived conformances.
// Previously, this crashed due to duplicate memberwise initializer synthesis.
struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {}
Expand All @@ -21,22 +50,32 @@ struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {}
struct ImmutableStoredProperties: Differentiable {
var okay: Float

// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let nondiff: Int

// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let diff: Float

let nonmutatingMoveAlongStruct: EmptyWithConcreteNonmutatingMoveAlong

let inheritedNonmutatingMoveAlongStruct: EmptyWithInheritedNonmutatingMoveAlong

let diffClass: EmptyClass // No error on class-bound `let` with a non-mutating `move(along:)`.
}
func testImmutableStoredProperties() {
_ = ImmutableStoredProperties.TangentVector(okay: 1)
_ = ImmutableStoredProperties.TangentVector(
okay: 1,
nonmutatingMoveAlongStruct: Empty.TangentVector(),
inheritedNonmutatingMoveAlongStruct: Empty.TangentVector(),
diffClass: EmptyClass.TangentVector())
}
struct MutableStoredPropertiesWithInitialValue: Differentiable {
var x = Float(1)
var y = Double(1)
}
// Test struct with both an empty constructor and memberwise initializer.
struct AllMixedStoredPropertiesHaveInitialValue: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let x = Float(1)
var y = Float(1)
// Memberwise initializer should be `init(y:)` since `x` is immutable.
Expand Down Expand Up @@ -363,7 +402,7 @@ struct Generic<T> {}
extension Generic: Differentiable where T: Differentiable {}

struct WrappedProperties: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}}
@ImmutableWrapper var immutableInt: Generic<Int>

// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
Expand Down
15 changes: 15 additions & 0 deletions test/AutoDiff/validation-test/class_differentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,19 @@ ClassTests.test("ClassProperties") {
gradient(at: Super(base: 2)) { foo in foo.squared })
}

ClassTests.test("LetProperties") {
final class Foo: Differentiable {
var x: Tracked<Float>
init(x: Tracked<Float>) { self.x = x }
}
final class Bar: Differentiable {
let x = Foo(x: 2)
}
let bar = Bar()
let grad = gradient(at: bar) { bar in (bar.x.x * bar.x.x).value }
expectEqual(Bar.TangentVector(x: .init(x: 6.0)), grad)
bar.move(along: grad)
expectEqual(8.0, bar.x.x)
}

runAllTests()