Skip to content

Commit 7551278

Browse files
authored
[AutoDiff] Make @noDerivative attribute imply non-varying semantics. (#29543)
Previously, `@noDerivative` could only be declared on stored properties in `Differentiable`-conforming structs and classes. Now, `@noDerivative` can be declared on all function-like declarations: `func`, `var`, `init`, and `subscript`. `@noDerivative` attribute now also implies `@_semantics("autodiff.nonvarying")`. SIL values produced from `@noDerivative` declarations will never be marked as varying by differentiable activity analysis. These values do not need a derivative. Marking declarations as `@noDerivative` is a usability improvement over using `withoutDerivative(at:)` at use sites. However, `@noDerivative` is invasive because it requires annotation on original declarations. There may exist a more elegant solution.
1 parent e3409e4 commit 7551278

File tree

9 files changed

+178
-55
lines changed

9 files changed

+178
-55
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
562562
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
563563
NotSerialized, 102)
564564
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
565-
OnVar |
565+
OnAbstractFunction | OnVar | OnSubscript |
566566
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
567567
103)
568568
// SWIFT_ENABLE_TENSORFLOW END

lib/Sema/TypeCheckAttr.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4860,29 +4860,11 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {
48604860

48614861
// SWIFT_ENABLE_TENSORFLOW
48624862
void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
4863-
auto *vd = dyn_cast<VarDecl>(D);
4864-
if (attr->isImplicit())
4865-
return;
4866-
if (!vd || vd->isStatic()) {
4867-
diagnoseAndRemoveAttr(attr,
4868-
diag::noderivative_only_on_differentiable_struct_or_class_fields);
4869-
return;
4870-
}
4871-
auto *nominal = vd->getDeclContext()->getSelfNominalTypeDecl();
4872-
if (!nominal || (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))) {
4873-
diagnoseAndRemoveAttr(attr,
4874-
diag::noderivative_only_on_differentiable_struct_or_class_fields);
4875-
return;
4876-
}
4877-
// Find any `Differentiable` conformance for the nominal type. If no such
4878-
// conformance exists, emit an error.
4879-
auto *diffProto =
4880-
nominal->getASTContext().getProtocol(KnownProtocolKind::Differentiable);
4881-
auto conf = nominal->getModuleContext()->lookupConformance(
4882-
nominal->getDeclaredInterfaceType(), diffProto);
4883-
if (!conf) {
4884-
diagnoseAndRemoveAttr(
4885-
attr, diag::noderivative_only_on_differentiable_struct_or_class_fields);
4886-
return;
4887-
}
4863+
auto &ctx = D->getASTContext();
4864+
// `@noDerivative` implies non-varying semantics for differentiable activity
4865+
// analysis. SIL values produced from references to `@noDerivative`
4866+
// declarations will not be marked as varying; these values do not need a
4867+
// derivative.
4868+
D->getAttrs().add(
4869+
new (ctx) SemanticsAttr("autodiff.nonvarying", /*implicit*/ true));
48884870
}

test/AutoDiff/downstream/derived_differentiable.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct TestNoDerivative : EuclideanDifferentiable {
4242

4343
// CHECK-AST-LABEL: internal struct TestNoDerivative : EuclideanDifferentiable {
4444
// CHECK-AST: var w: Float
45-
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
45+
// CHECK-AST: @noDerivative @_semantics("autodiff.nonvarying") internal var technicallyDifferentiable: Float
4646
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
4747
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions
4848
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
@@ -55,7 +55,7 @@ struct TestPointwiseMultiplicative : Differentiable {
5555

5656
// CHECK-AST-LABEL: internal struct TestPointwiseMultiplicative : Differentiable {
5757
// CHECK-AST: var w: PointwiseMultiplicativeDummy
58-
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
58+
// CHECK-AST: @noDerivative @_semantics("autodiff.nonvarying") internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
5959
// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
6060
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
6161
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector
@@ -68,7 +68,7 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
6868

6969
// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
7070
// CHECK-AST: var w: Float
71-
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
71+
// CHECK-AST: @noDerivative @_semantics("autodiff.nonvarying") internal var technicallyDifferentiable: Float
7272
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
7373
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
7474
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector

test/AutoDiff/downstream/noderivative-attr.swift

Lines changed: 0 additions & 24 deletions
This file was deleted.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: %target-swift-frontend -emit-silgen -verify %s | %FileCheck %s
2+
// REQUIRES: asserts
3+
4+
@noDerivative var flag: Bool
5+
6+
struct NotDifferentiable {
7+
@noDerivative var stored: Float
8+
9+
@noDerivative
10+
var computedProperty: Float {
11+
get { 1 }
12+
set {}
13+
_modify { yield &stored }
14+
}
15+
16+
@noDerivative
17+
func instanceMethod(_ x: Float) -> Float { x }
18+
19+
@noDerivative
20+
static func staticMethod(_ x: Float) -> Float { x }
21+
22+
@noDerivative
23+
subscript(_ x: Float) -> Float {
24+
get { 1 }
25+
set {}
26+
_modify { yield &stored }
27+
}
28+
}
29+
30+
// CHECK-LABEL: struct NotDifferentiable {
31+
// CHECK: @noDerivative @_hasStorage @_semantics("autodiff.nonvarying") var stored: Float { get set }
32+
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") var computedProperty: Float { get set _modify }
33+
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") func instanceMethod(_ x: Float) -> Float
34+
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") static func staticMethod(_ x: Float) -> Float
35+
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") subscript(x: Float) -> Float { get set _modify }
36+
// CHECK: }
37+
38+
// CHECK-LABEL: // NotDifferentiable.computedProperty.getter
39+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvg : $@convention(method) (NotDifferentiable) -> Float
40+
41+
// CHECK-LABEL: // NotDifferentiable.computedProperty.setter
42+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvs : $@convention(method) (Float, @inout NotDifferentiable) -> ()
43+
44+
// CHECK-LABEL: // NotDifferentiable.computedProperty.modify
45+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvM : $@yield_once @convention(method) (@inout NotDifferentiable) -> @yields @inout Float
46+
47+
// CHECK-LABEL: // NotDifferentiable.instanceMethod(_:)
48+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV14instanceMethodyS2fF : $@convention(method) (Float, NotDifferentiable) -> Float
49+
50+
// CHECK-LABEL: // static NotDifferentiable.staticMethod(_:)
51+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV12staticMethodyS2fFZ : $@convention(method) (Float, @thin NotDifferentiable.Type) -> Float
52+
53+
// CHECK-LABEL: // NotDifferentiable.subscript.getter
54+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fcig : $@convention(method) (Float, NotDifferentiable) -> Float
55+
56+
// CHECK-LABEL: // NotDifferentiable.subscript.setter
57+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fcis : $@convention(method) (Float, Float, @inout NotDifferentiable) -> ()
58+
59+
// CHECK-LABEL: // NotDifferentiable.subscript.modify
60+
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fciM : $@yield_once @convention(method) (Float, @inout NotDifferentiable) -> @yields @inout Float
61+
62+
struct Bar: Differentiable {
63+
@noDerivative var stored: Float
64+
}
65+
66+
// Test TF-152: derived conformances "no interface type set" crasher.
67+
struct TF_152: Differentiable {
68+
@differentiable(wrt: bar)
69+
func applied(to input: Float, bar: TF_152_Bar) -> Float {
70+
return input
71+
}
72+
}
73+
struct TF_152_Bar: Differentiable {
74+
@noDerivative let dense: Float
75+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
2+
// REQUIRES: asserts
3+
4+
// Test approaches for affecting activity analysis (non-varying semantics):
5+
// - `@noDerivative` on declaration
6+
// - `@_semantics("autodiff.nonvarying")` on declaration
7+
// - `withoutDerivative(at:)` at use site
8+
9+
extension Float {
10+
// No non-varying semantics.
11+
var int: Int { Int(self) }
12+
13+
// Non-varying semantics.
14+
@noDerivative
15+
var intNoDerivative: Int { int }
16+
17+
// Non-varying semantics.
18+
@_semantics("autodiff.nonvarying")
19+
var intNonvarying: Int { int }
20+
}
21+
22+
// expected-error @+1 {{function is not differentiable}}
23+
@differentiable
24+
@_silgen_name("id")
25+
// expected-note @+1 {{when differentiating this function definition}}
26+
func id(_ x: Float) -> Float {
27+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
28+
return Float(x.int)
29+
}
30+
31+
// CHECK-LABEL: [AD] Activity info for id at (source=0 parameters=(0))
32+
// CHECK: bb0:
33+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
34+
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
35+
// CHECK: [NONE] // function_ref Float.int.getter
36+
// CHECK: [ACTIVE] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
37+
// CHECK: [NONE] // function_ref Float.init(_:)
38+
// CHECK: [ACTIVE] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
39+
40+
@differentiable
41+
@_silgen_name("idWithoutDerivativeAt")
42+
func idWithoutDerivativeAt(_ x: Float) -> Float {
43+
return Float(withoutDerivative(at: x.int))
44+
}
45+
46+
// CHECK-LABEL: [AD] Activity info for idWithoutDerivativeAt at (source=0 parameters=(0))
47+
// CHECK: bb0:
48+
// CHECK: [VARIED] %0 = argument of bb0 : $Float
49+
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
50+
// CHECK: [USEFUL] %3 = alloc_stack $Int
51+
// CHECK: [NONE] // function_ref Float.int.getter
52+
// CHECK: [VARIED] %5 = apply %4(%0) : $@convention(method) (Float) -> Int
53+
// CHECK: [VARIED] %6 = alloc_stack $Int
54+
// CHECK: [NONE] // function_ref withoutDerivative<A>(at:)
55+
// CHECK: [NONE] %9 = apply %8<Int>(%3, %6) : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> @out τ_0_0
56+
// CHECK: [USEFUL] %11 = load [trivial] %3 : $*Int
57+
// CHECK: [NONE] // function_ref Float.init(_:)
58+
// CHECK: [USEFUL] %13 = apply %12(%11, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
59+
60+
@differentiable
61+
@_silgen_name("idNoDerivative")
62+
func idNoDerivative(_ x: Float) -> Float {
63+
return Float(x.intNoDerivative)
64+
}
65+
66+
// CHECK-LABEL: [AD] Activity info for idNoDerivative at (source=0 parameters=(0))
67+
// CHECK: bb0:
68+
// CHECK: [VARIED] %0 = argument of bb0 : $Float
69+
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
70+
// CHECK: [NONE] // function_ref Float.intNoDerivative.getter
71+
// CHECK: [USEFUL] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
72+
// CHECK: [NONE] // function_ref Float.init(_:)
73+
// CHECK: [USEFUL] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
74+
75+
@differentiable
76+
@_silgen_name("idNonvaryingSemantics")
77+
func idNonvaryingSemantics(_ x: Float) -> Float {
78+
return Float(x.intNonvarying)
79+
}
80+
81+
// CHECK-LABEL: [AD] Activity info for idNonvaryingSemantics at (source=0 parameters=(0))
82+
// CHECK: bb0:
83+
// CHECK: [VARIED] %0 = argument of bb0 : $Float
84+
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
85+
// CHECK: [NONE] // function_ref Float.intNonvarying.getter
86+
// CHECK: [USEFUL] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
87+
// CHECK: [NONE] // function_ref Float.init(_:)
88+
// CHECK: [USEFUL] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float

test/IDE/complete_decl_attribute.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct MyStruct {}
7070
// KEYWORD2-NEXT: Keyword/None: quoted[#Func Attribute#]; name=quoted
7171
// KEYWORD2-NEXT: Keyword/None: differentiating[#Func Attribute#]; name=differentiating
7272
// KEYWORD2-NEXT: Keyword/None: compilerEvaluable[#Func Attribute#]; name=compilerEvaluable
73+
// KEYWORD2-NEXT: Keyword/None: noDerivative[#Func Attribute#]; name=noDerivative
7374
// SWIFT_ENABLE_TENSORFLOW END
7475
// KEYWORD2-NOT: Keyword
7576
// KEYWORD2: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
@@ -193,6 +194,7 @@ struct _S {
193194
// ON_METHOD-DAG: Keyword/None: quoted[#Func Attribute#]; name=quoted
194195
// ON_METHOD-DAG: Keyword/None: differentiating[#Func Attribute#]; name=differentiating
195196
// ON_METHOD-DAG: Keyword/None: compilerEvaluable[#Func Attribute#]; name=compilerEvaluable
197+
// ON_METHOD-DAG: Keyword/None: noDerivative[#Func Attribute#]; name=noDerivative
196198
// SWIFT_ENABLE_TENSORFLOW END
197199
// ON_METHOD-NOT: Keyword
198200
// ON_METHOD: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct

test/Sema/class_differentiable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ final class VectorSpaceCustomStruct : DummyAdditiveArithmetic, Differentiable {
503503
}
504504

505505
class StaticNoDerivative : Differentiable {
506-
@noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
506+
@noDerivative static var s: Bool = true
507507
}
508508

509509
final class StaticMembersShouldNotAffectAnything : DummyAdditiveArithmetic, Differentiable {

test/Sema/struct_differentiable.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable {
343343
}
344344

345345
struct StaticNoDerivative : Differentiable {
346-
@noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
346+
@noDerivative static var s: Bool = true
347347
}
348348

349349
struct StaticMembersShouldNotAffectAnything : AdditiveArithmetic, Differentiable {

0 commit comments

Comments
 (0)