Skip to content

Commit a5bdbd9

Browse files
authored
[AutoDiff] Fix derivative function configurations for accessors. (#31669)
For accessors: make `AbstractFunctionDecl::getDerivativeFunctionConfigurations` resolve configurations from parent storage declaration `@differentiable` attributes. Fixes "no `@differentiable` attribute" non-differentiability error for accessors whose parent storage declaration `@differentiable` attributes have not been type-checked (e.g. because the storage declarations are in another file). Add protocol requirement and class member storage declaration tests. Resolves TF-1234.
1 parent dee6c0b commit a5bdbd9

File tree

4 files changed

+73
-6
lines changed

4 files changed

+73
-6
lines changed

lib/AST/Decl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7111,10 +7111,19 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
71117111
ArrayRef<AutoDiffConfig>
71127112
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
71137113
prepareDerivativeFunctionConfigurations();
7114+
71147115
// Resolve derivative function configurations from `@differentiable`
71157116
// attributes by type-checking them.
71167117
for (auto *diffAttr : getAttrs().getAttributes<DifferentiableAttr>())
71177118
(void)diffAttr->getParameterIndices();
7119+
// For accessors: resolve derivative function configurations from storage
7120+
// `@differentiable` attributes by type-checking them.
7121+
if (auto *accessor = dyn_cast<AccessorDecl>(this)) {
7122+
auto *storage = accessor->getStorage();
7123+
for (auto *diffAttr : storage->getAttrs().getAttributes<DifferentiableAttr>())
7124+
(void)diffAttr->getParameterIndices();
7125+
}
7126+
71187127
// Load derivative configurations from imported modules.
71197128
auto &ctx = getASTContext();
71207129
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -656,9 +656,9 @@ emitDerivativeFunctionReference(
656656
auto loc = witnessMethod->getLoc();
657657
auto requirementDeclRef = witnessMethod->getMember();
658658
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
659-
// If requirement declaration does not have any `@differentiable`
660-
// attributes, produce an error.
661-
if (!requirementDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
659+
// If requirement declaration does not have any derivative function
660+
// configurations, produce an error.
661+
if (requirementDecl->getDerivativeFunctionConfigurations().empty()) {
662662
context.emitNondifferentiabilityError(
663663
original, invoker, diag::autodiff_protocol_member_not_differentiable);
664664
return None;
@@ -701,9 +701,9 @@ emitDerivativeFunctionReference(
701701
auto loc = classMethod->getLoc();
702702
auto methodDeclRef = classMethod->getMember();
703703
auto *methodDecl = methodDeclRef.getAbstractFunctionDecl();
704-
// If method declaration does not have any `@differentiable` attributes,
705-
// produce an error.
706-
if (!methodDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
704+
// If method declaration does not have any derivative function
705+
// configurations, produce an error.
706+
if (methodDecl->getDerivativeFunctionConfigurations().empty()) {
707707
context.emitNondifferentiabilityError(
708708
original, invoker, diag::autodiff_class_member_not_differentiable);
709709
return None;

test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@ protocol Protocol: Differentiable {
44
// Test cross-file `@differentiable` attribute.
55
@differentiable(wrt: self)
66
func identityDifferentiableAttr() -> Self
7+
8+
// Test `@differentiable` propagation from storage declaration to accessors.
9+
@differentiable
10+
var property: Float { get set }
11+
12+
// Test `@differentiable` propagation from storage declaration to accessors.
13+
@differentiable
14+
subscript() -> Float { get set }
715
}
816

917
extension Protocol {
@@ -17,3 +25,19 @@ extension Protocol {
1725
fatalError()
1826
}
1927
}
28+
29+
class Class: Differentiable {
30+
// Test `@differentiable` propagation from storage declaration to accessors.
31+
@differentiable
32+
var property: Float {
33+
get { 1 }
34+
set {}
35+
}
36+
37+
// Test `@differentiable` propagation from storage declaration to accessors.
38+
@differentiable
39+
subscript() -> Float {
40+
get { 1 }
41+
set {}
42+
}
43+
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,37 @@ func crossFileDerivativeAttr<T: Protocol>(
2323
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
2424
return input.identityDerivativeAttr()
2525
}
26+
27+
// TF-1234: Test `@differentiable` propagation from protocol requirement storage
28+
// declarations to their accessors in other file.
29+
30+
@differentiable
31+
func protocolRequirementGetters<T: Protocol>(_ x: T) -> Float {
32+
x.property + x[]
33+
}
34+
35+
// TODO(TF-1184): Make `@differentiable` on storage declarations propagate to
36+
// the setter in addition to the getter.
37+
@differentiable
38+
func protocolRequirementSetters<T: Protocol>(_ x: inout T, _ newValue: Float) {
39+
// expected-error @+2 {{expression is not differentiable}}
40+
// expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
41+
x.property = newValue
42+
// expected-error @+2 {{expression is not differentiable}}
43+
// expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
44+
x[] = newValue
45+
}
46+
47+
// TF-1234: Test `@differentiable` propagation from class member storage
48+
// declarations to their accessors in other file.
49+
50+
@differentiable
51+
func classRequirementGetters(_ x: Class) -> Float {
52+
x.property + x[]
53+
}
54+
55+
@differentiable
56+
func classRequirementSetters(_ x: inout Class, _ newValue: Float) {
57+
x.property = newValue
58+
x[] = newValue
59+
}

0 commit comments

Comments
 (0)