|
| 1 | +// SWIFT_ENABLE_TENSORFLOW |
| 2 | + |
| 3 | +// RUN: %empty-directory(%t) |
| 4 | +// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t |
| 5 | +// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER |
| 6 | +// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s |
| 7 | +// REQUIRES: differentiable_programming |
| 8 | + |
| 9 | +// TODO(TF-836): Enable this test. |
| 10 | +// Blocked by TF-828: `@differentiating` attribute type-checking. |
| 11 | +// XFAIL: * |
| 12 | + |
| 13 | +// BCANALYZER-NOT: UnknownCode |
| 14 | + |
| 15 | +import _Differentiation |
| 16 | + |
| 17 | +// CHECK: @differentiable(wrt: x, jvp: jvpSimple, vjp: vjpSimple) |
| 18 | +// CHECK-NEXT: func simple(x: Float) -> Float |
| 19 | +@differentiable(jvp: jvpSimple, vjp: vjpSimple) |
| 20 | +func simple(x: Float) -> Float { |
| 21 | + return x |
| 22 | +} |
| 23 | + |
| 24 | +// CHECK: @differentiable(linear, wrt: x) |
| 25 | +// CHECK-NEXT: func simple2(x: Float) -> Float |
| 26 | +@differentiable(linear) |
| 27 | +func simple2(x: Float) -> Float { |
| 28 | + return x |
| 29 | +} |
| 30 | + |
| 31 | +// CHECK: @differentiable(linear, wrt: x) |
| 32 | +// CHECK-NEXT: func simple4(x: Float) -> Float |
| 33 | +@differentiable(linear, wrt: x) |
| 34 | +func simple4(x: Float) -> Float { |
| 35 | + return x |
| 36 | +} |
| 37 | + |
| 38 | +func jvpSimple(x: Float) -> (Float, (Float) -> Float) { |
| 39 | + return (x, { v in v }) |
| 40 | +} |
| 41 | + |
| 42 | +func vjpSimple(x: Float) -> (Float, (Float) -> Float) { |
| 43 | + return (x, { v in v }) |
| 44 | +} |
| 45 | + |
| 46 | +// CHECK: @differentiable(wrt: x) |
| 47 | +// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float |
| 48 | +@differentiable(wrt: x) |
| 49 | +func testWrtClause(x: Float, y: Float) -> Float { |
| 50 | + return x + y |
| 51 | +} |
| 52 | + |
| 53 | +struct InstanceMethod : Differentiable { |
| 54 | + // CHECK: @differentiable(wrt: (self, y)) |
| 55 | + // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float |
| 56 | + @differentiable(wrt: (self, y)) |
| 57 | + func testWrtClause(x: Float, y: Float) -> Float { |
| 58 | + return x + y |
| 59 | + } |
| 60 | + |
| 61 | + struct TangentVector: Differentiable, AdditiveArithmetic { |
| 62 | + typealias TangentVector = Self |
| 63 | + static func ==(_: Self, _: Self) -> Bool { fatalError() } |
| 64 | + static var zero: Self { fatalError() } |
| 65 | + static func +(_: Self, _: Self) -> Self { fatalError() } |
| 66 | + static func -(_: Self, _: Self) -> Self { fatalError() } |
| 67 | + } |
| 68 | + mutating func move(along direction: TangentVector) {} |
| 69 | +} |
| 70 | + |
| 71 | +// CHECK: @differentiable(wrt: x where T : Differentiable) |
| 72 | +// CHECK-NEXT: func testOnlyWhereClause<T>(x: T) -> T where T : Numeric |
| 73 | +@differentiable(where T : Differentiable) |
| 74 | +func testOnlyWhereClause<T : Numeric>(x: T) -> T { |
| 75 | + return x |
| 76 | +} |
| 77 | + |
| 78 | +// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClause where T : Differentiable) |
| 79 | +// CHECK-NEXT: func testWhereClause<T>(x: T) -> T where T : Numeric |
| 80 | +@differentiable(vjp: vjpTestWhereClause where T : Differentiable) |
| 81 | +func testWhereClause<T : Numeric>(x: T) -> T { |
| 82 | + return x |
| 83 | +} |
| 84 | +func vjpTestWhereClause<T>(x: T) -> (T, (T.TangentVector) -> T.TangentVector) |
| 85 | + where T : Numeric, T : Differentiable |
| 86 | +{ |
| 87 | + return (x, { v in v }) |
| 88 | +} |
| 89 | + |
| 90 | +protocol P {} |
| 91 | +extension P { |
| 92 | + // CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethod where Self : Differentiable) |
| 93 | + // CHECK-NEXT: func testWhereClauseMethod() -> Self |
| 94 | + @differentiable(wrt: self, vjp: vjpTestWhereClauseMethod where Self : Differentiable) |
| 95 | + func testWhereClauseMethod() -> Self { |
| 96 | + return self |
| 97 | + } |
| 98 | +} |
| 99 | +extension P where Self : Differentiable { |
| 100 | + func vjpTestWhereClauseMethod() -> (Self, (Self.TangentVector) -> Self.TangentVector) { |
| 101 | + return (self, { v in v }) |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector) |
| 106 | +// CHECK-NEXT: func testWhereClauseMethodTypeConstraint<T>(x: T) -> T where T : Numeric |
| 107 | +@differentiable(vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector) |
| 108 | +func testWhereClauseMethodTypeConstraint<T : Numeric>(x: T) -> T { |
| 109 | + return x |
| 110 | +} |
| 111 | +func vjpTestWhereClauseMethodTypeConstraint<T>(x: T) -> (T, (T) -> T) |
| 112 | + where T : Numeric, T : Differentiable, T == T.TangentVector |
| 113 | +{ |
| 114 | + return (x, { v in v }) |
| 115 | +} |
| 116 | + |
| 117 | +extension P { |
| 118 | + // CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self : Differentiable, Self == Self.TangentVector) |
| 119 | + // CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self |
| 120 | + @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self.TangentVector == Self, Self : Differentiable) |
| 121 | + func testWhereClauseMethodTypeConstraint() -> Self { |
| 122 | + return self |
| 123 | + } |
| 124 | +} |
| 125 | +extension P where Self : Differentiable, Self == Self.TangentVector { |
| 126 | + func vjpTestWhereClauseMethodTypeConstraint() -> (Self, (Self.TangentVector) -> Self.TangentVector) { |
| 127 | + return (self, { v in v }) |
| 128 | + } |
| 129 | +} |
0 commit comments