Skip to content

[AutoDiff] Make @noDerivative attribute imply non-varying semantics. #29543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
NotSerialized, 102)
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
OnVar |
OnAbstractFunction | OnVar | OnSubscript |
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
103)
// SWIFT_ENABLE_TENSORFLOW END
Expand Down
32 changes: 7 additions & 25 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4860,29 +4860,11 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {

// SWIFT_ENABLE_TENSORFLOW
void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
auto *vd = dyn_cast<VarDecl>(D);
if (attr->isImplicit())
return;
if (!vd || vd->isStatic()) {
diagnoseAndRemoveAttr(attr,
diag::noderivative_only_on_differentiable_struct_or_class_fields);
return;
}
auto *nominal = vd->getDeclContext()->getSelfNominalTypeDecl();
if (!nominal || (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))) {
diagnoseAndRemoveAttr(attr,
diag::noderivative_only_on_differentiable_struct_or_class_fields);
return;
}
// Find any `Differentiable` conformance for the nominal type. If no such
// conformance exists, emit an error.
auto *diffProto =
nominal->getASTContext().getProtocol(KnownProtocolKind::Differentiable);
auto conf = nominal->getModuleContext()->lookupConformance(
nominal->getDeclaredInterfaceType(), diffProto);
if (!conf) {
diagnoseAndRemoveAttr(
attr, diag::noderivative_only_on_differentiable_struct_or_class_fields);
return;
}
auto &ctx = D->getASTContext();
// `@noDerivative` implies non-varying semantics for differentiable activity
// analysis. SIL values produced from references to `@noDerivative`
// declarations will not be marked as varying; these values do not need a
// derivative.
D->getAttrs().add(
new (ctx) SemanticsAttr("autodiff.nonvarying", /*implicit*/ true));
}
6 changes: 3 additions & 3 deletions test/AutoDiff/downstream/derived_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct TestNoDerivative : EuclideanDifferentiable {

// CHECK-AST-LABEL: internal struct TestNoDerivative : EuclideanDifferentiable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: @noDerivative @_semantics("autodiff.nonvarying") internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
Expand All @@ -55,7 +55,7 @@ struct TestPointwiseMultiplicative : Differentiable {

// CHECK-AST-LABEL: internal struct TestPointwiseMultiplicative : Differentiable {
// CHECK-AST: var w: PointwiseMultiplicativeDummy
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
// CHECK-AST: @noDerivative @_semantics("autodiff.nonvarying") internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector
Expand All @@ -68,7 +68,7 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {

// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: @noDerivative @_semantics("autodiff.nonvarying") internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector
Expand Down
24 changes: 0 additions & 24 deletions test/AutoDiff/downstream/noderivative-attr.swift

This file was deleted.

75 changes: 75 additions & 0 deletions test/AutoDiff/downstream/noderivative_attr.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: %target-swift-frontend -emit-silgen -verify %s | %FileCheck %s
// REQUIRES: asserts

@noDerivative var flag: Bool

struct NotDifferentiable {
@noDerivative var stored: Float

@noDerivative
var computedProperty: Float {
get { 1 }
set {}
_modify { yield &stored }
}

@noDerivative
func instanceMethod(_ x: Float) -> Float { x }

@noDerivative
static func staticMethod(_ x: Float) -> Float { x }

@noDerivative
subscript(_ x: Float) -> Float {
get { 1 }
set {}
_modify { yield &stored }
}
}

// CHECK-LABEL: struct NotDifferentiable {
// CHECK: @noDerivative @_hasStorage @_semantics("autodiff.nonvarying") var stored: Float { get set }
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") var computedProperty: Float { get set _modify }
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") func instanceMethod(_ x: Float) -> Float
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") static func staticMethod(_ x: Float) -> Float
// CHECK: @noDerivative @_semantics("autodiff.nonvarying") subscript(x: Float) -> Float { get set _modify }
// CHECK: }

// CHECK-LABEL: // NotDifferentiable.computedProperty.getter
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvg : $@convention(method) (NotDifferentiable) -> Float

// CHECK-LABEL: // NotDifferentiable.computedProperty.setter
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvs : $@convention(method) (Float, @inout NotDifferentiable) -> ()

// CHECK-LABEL: // NotDifferentiable.computedProperty.modify
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV16computedPropertySfvM : $@yield_once @convention(method) (@inout NotDifferentiable) -> @yields @inout Float

// CHECK-LABEL: // NotDifferentiable.instanceMethod(_:)
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV14instanceMethodyS2fF : $@convention(method) (Float, NotDifferentiable) -> Float

// CHECK-LABEL: // static NotDifferentiable.staticMethod(_:)
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableV12staticMethodyS2fFZ : $@convention(method) (Float, @thin NotDifferentiable.Type) -> Float

// CHECK-LABEL: // NotDifferentiable.subscript.getter
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fcig : $@convention(method) (Float, NotDifferentiable) -> Float

// CHECK-LABEL: // NotDifferentiable.subscript.setter
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fcis : $@convention(method) (Float, Float, @inout NotDifferentiable) -> ()

// CHECK-LABEL: // NotDifferentiable.subscript.modify
// CHECK: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @$s17noderivative_attr17NotDifferentiableVyS2fciM : $@yield_once @convention(method) (Float, @inout NotDifferentiable) -> @yields @inout Float

struct Bar: Differentiable {
@noDerivative var stored: Float
}

// Test TF-152: derived conformances "no interface type set" crasher.
struct TF_152: Differentiable {
@differentiable(wrt: bar)
func applied(to input: Float, bar: TF_152_Bar) -> Float {
return input
}
}
struct TF_152_Bar: Differentiable {
@noDerivative let dense: Float
}
88 changes: 88 additions & 0 deletions test/AutoDiff/downstream/nonvarying_semantics.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
// REQUIRES: asserts

// Test approaches for affecting activity analysis (non-varying semantics):
// - `@noDerivative` on declaration
// - `@_semantics("autodiff.nonvarying")` on declaration
// - `withoutDerivative(at:)` at use site

extension Float {
// No non-varying semantics.
var int: Int { Int(self) }

// Non-varying semantics.
@noDerivative
var intNoDerivative: Int { int }

// Non-varying semantics.
@_semantics("autodiff.nonvarying")
var intNonvarying: Int { int }
}

// expected-error @+1 {{function is not differentiable}}
@differentiable
@_silgen_name("id")
// expected-note @+1 {{when differentiating this function definition}}
func id(_ x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
return Float(x.int)
}

// CHECK-LABEL: [AD] Activity info for id at (source=0 parameters=(0))
// CHECK: bb0:
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
// CHECK: [NONE] // function_ref Float.int.getter
// CHECK: [ACTIVE] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
// CHECK: [NONE] // function_ref Float.init(_:)
// CHECK: [ACTIVE] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float

@differentiable
@_silgen_name("idWithoutDerivativeAt")
func idWithoutDerivativeAt(_ x: Float) -> Float {
return Float(withoutDerivative(at: x.int))
}

// CHECK-LABEL: [AD] Activity info for idWithoutDerivativeAt at (source=0 parameters=(0))
// CHECK: bb0:
// CHECK: [VARIED] %0 = argument of bb0 : $Float
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
// CHECK: [USEFUL] %3 = alloc_stack $Int
// CHECK: [NONE] // function_ref Float.int.getter
// CHECK: [VARIED] %5 = apply %4(%0) : $@convention(method) (Float) -> Int
// CHECK: [VARIED] %6 = alloc_stack $Int
// CHECK: [NONE] // function_ref withoutDerivative<A>(at:)
// CHECK: [NONE] %9 = apply %8<Int>(%3, %6) : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0) -> @out τ_0_0
// CHECK: [USEFUL] %11 = load [trivial] %3 : $*Int
// CHECK: [NONE] // function_ref Float.init(_:)
// CHECK: [USEFUL] %13 = apply %12(%11, %2) : $@convention(method) (Int, @thin Float.Type) -> Float

@differentiable
@_silgen_name("idNoDerivative")
func idNoDerivative(_ x: Float) -> Float {
return Float(x.intNoDerivative)
}

// CHECK-LABEL: [AD] Activity info for idNoDerivative at (source=0 parameters=(0))
// CHECK: bb0:
// CHECK: [VARIED] %0 = argument of bb0 : $Float
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
// CHECK: [NONE] // function_ref Float.intNoDerivative.getter
// CHECK: [USEFUL] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
// CHECK: [NONE] // function_ref Float.init(_:)
// CHECK: [USEFUL] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float

@differentiable
@_silgen_name("idNonvaryingSemantics")
func idNonvaryingSemantics(_ x: Float) -> Float {
return Float(x.intNonvarying)
}

// CHECK-LABEL: [AD] Activity info for idNonvaryingSemantics at (source=0 parameters=(0))
// CHECK: bb0:
// CHECK: [VARIED] %0 = argument of bb0 : $Float
// CHECK: [USEFUL] %2 = metatype $@thin Float.Type
// CHECK: [NONE] // function_ref Float.intNonvarying.getter
// CHECK: [USEFUL] %4 = apply %3(%0) : $@convention(method) (Float) -> Int
// CHECK: [NONE] // function_ref Float.init(_:)
// CHECK: [USEFUL] %6 = apply %5(%4, %2) : $@convention(method) (Int, @thin Float.Type) -> Float
2 changes: 2 additions & 0 deletions test/IDE/complete_decl_attribute.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct MyStruct {}
// KEYWORD2-NEXT: Keyword/None: quoted[#Func Attribute#]; name=quoted
// KEYWORD2-NEXT: Keyword/None: differentiating[#Func Attribute#]; name=differentiating
// KEYWORD2-NEXT: Keyword/None: compilerEvaluable[#Func Attribute#]; name=compilerEvaluable
// KEYWORD2-NEXT: Keyword/None: noDerivative[#Func Attribute#]; name=noDerivative
// SWIFT_ENABLE_TENSORFLOW END
// KEYWORD2-NOT: Keyword
// KEYWORD2: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
Expand Down Expand Up @@ -193,6 +194,7 @@ struct _S {
// ON_METHOD-DAG: Keyword/None: quoted[#Func Attribute#]; name=quoted
// ON_METHOD-DAG: Keyword/None: differentiating[#Func Attribute#]; name=differentiating
// ON_METHOD-DAG: Keyword/None: compilerEvaluable[#Func Attribute#]; name=compilerEvaluable
// ON_METHOD-DAG: Keyword/None: noDerivative[#Func Attribute#]; name=noDerivative
// SWIFT_ENABLE_TENSORFLOW END
// ON_METHOD-NOT: Keyword
// ON_METHOD: Decl[Struct]/CurrModule: MyStruct[#MyStruct#]; name=MyStruct
Expand Down
2 changes: 1 addition & 1 deletion test/Sema/class_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ final class VectorSpaceCustomStruct : DummyAdditiveArithmetic, Differentiable {
}

class StaticNoDerivative : Differentiable {
@noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
@noDerivative static var s: Bool = true
}

final class StaticMembersShouldNotAffectAnything : DummyAdditiveArithmetic, Differentiable {
Expand Down
2 changes: 1 addition & 1 deletion test/Sema/struct_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable {
}

struct StaticNoDerivative : Differentiable {
@noDerivative static var s: Bool = true // expected-error {{'@noDerivative' is only allowed on stored properties in structure or class types that declare a conformance to 'Differentiable'}}
@noDerivative static var s: Bool = true
}

struct StaticMembersShouldNotAffectAnything : AdditiveArithmetic, Differentiable {
Expand Down