Skip to content

Commit 7897fac

Browse files
committed
Don't allow marking stored props as differentiable
1 parent 1d269ca commit 7897fac

File tree

4 files changed

+39
-24
lines changed

4 files changed

+39
-24
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
27332733
"layout requirement are not supported by '@differentiable' attribute", ())
27342734
ERROR(differentiable_attr_class_unsupported,none,
27352735
"class members cannot be marked with '@differentiable'", ())
2736+
ERROR(differentiable_attr_stored_prop_unsupported,none,
2737+
"Stored properties cannot be marked with '@differentiable'", ())
27362738
NOTE(protocol_witness_missing_specific_differentiable_attr,none,
27372739
"candidate is missing attribute '%0'", (StringRef))
27382740

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3234,27 +3234,22 @@ class VJPEmitter final
32343234
}
32353235
// This instruction is active. Determine the appropriate differentiation
32363236
// strategy, and use it.
3237+
// Find the corresponding getter.
3238+
auto *getterDecl = sei->getField()->getGetter();
3239+
assert(getterDecl);
3240+
auto *getterFn = getModule().lookUpFunction(
3241+
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
32373242
auto *structDecl = sei->getStructDecl();
3238-
if (structDecl->getEffectiveAccess() <= AccessLevel::Internal ||
3243+
if (!getterFn ||
32393244
structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
32403245
strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise;
32413246
SILClonerWithScopes::visitStructExtractInst(sei);
32423247
return;
32433248
}
32443249
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
32453250
// strategy.
3251+
assert(getterFn);
32463252
strategies[sei] = StructExtractDifferentiationStrategy::Getter;
3247-
// Find the corresponding getter and its VJP.
3248-
auto *getterDecl = sei->getField()->getGetter();
3249-
assert(getterDecl);
3250-
auto *getterFn = getModule().lookUpFunction(
3251-
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
3252-
if (!getterFn) {
3253-
context.emitNondifferentiabilityError(
3254-
sei, invoker, diag::autodiff_property_not_differentiable);
3255-
errorOccurred = true;
3256-
return;
3257-
}
32583253
SILAutoDiffIndices indices(/*source*/ 0,
32593254
AutoDiffIndexSubset::getDefault(getASTContext(), 1, true));
32603255
auto *attr = context.lookUpDifferentiableAttr(getterFn, indices);
@@ -3304,27 +3299,22 @@ class VJPEmitter final
33043299
}
33053300
// This instruction is active. Determine the appropriate differentiation
33063301
// strategy, and use it.
3302+
// Find the corresponding getter.
3303+
auto *getterDecl = seai->getField()->getGetter();
3304+
assert(getterDecl);
3305+
auto *getterFn = getModule().lookUpFunction(
3306+
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
33073307
auto *structDecl = seai->getStructDecl();
3308-
if (structDecl->getEffectiveAccess() <= AccessLevel::Internal ||
3308+
if (!getterFn ||
33093309
structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
33103310
strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise;
33113311
SILClonerWithScopes::visitStructElementAddrInst(seai);
33123312
return;
33133313
}
33143314
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
33153315
// strategy.
3316+
assert(getterFn);
33163317
strategies[seai] = StructExtractDifferentiationStrategy::Getter;
3317-
// Find the corresponding getter and its VJP.
3318-
auto *getterDecl = seai->getField()->getGetter();
3319-
assert(getterDecl);
3320-
auto *getterFn = getModule().lookUpFunction(
3321-
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
3322-
if (!getterFn) {
3323-
context.emitNondifferentiabilityError(
3324-
seai, invoker, diag::autodiff_property_not_differentiable);
3325-
errorOccurred = true;
3326-
return;
3327-
}
33283318
SILAutoDiffIndices indices(/*source*/ 0,
33293319
AutoDiffIndexSubset::getDefault(getASTContext(), 1, true));
33303320
auto *attr = context.lookUpDifferentiableAttr(getterFn, indices);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2887,6 +2887,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
28872887

28882888
AbstractFunctionDecl *original = dyn_cast<AbstractFunctionDecl>(D);
28892889
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
2890+
if (asd->getImplInfo().isSimpleStored()) {
2891+
diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_prop_unsupported);
2892+
}
28902893
// When used directly on a storage decl (stored/computed property or
28912894
// subscript), the getter is currently inferred to be `@differentiable`.
28922895
// TODO(TF-129): Infer setter to also be `@differentiable` after

test/AutoDiff/differentiating_attr_type_checking.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,23 @@ func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float)
295295
func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
296296
return (x, { $0 })
297297
}
298+
299+
// Test usage of `@differentiable` on a stored property
300+
struct PropertyDiff : Differentiable & AdditiveArithmetic {
301+
// expected-error @+1 {{Stored properties cannot be marked with '@differentiable'}}
302+
@differentiable(vjp: vjpPropertyA)
303+
var a: Float = 1
304+
typealias TangentVector = PropertyDiff
305+
typealias AllDifferentiableVariables = PropertyDiff
306+
func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) {
307+
(.zero, { _ in .zero })
308+
}
309+
}
310+
311+
@differentiable
312+
func f(_ x: PropertyDiff) -> Float {
313+
return x.a
314+
}
315+
316+
let a = gradient(at: PropertyDiff(), in: f)
317+
print(a)

0 commit comments

Comments
 (0)