|
| 1 | +// RUN: %target-swift-frontend -emit-sil -O -g %s | %FileCheck %s |
| 2 | + |
| 3 | +// REQUIRES: swift_in_compiler |
| 4 | + |
| 5 | +// Fix for https://github.com/apple/swift/issues/62608 |
| 6 | +// We need to emit separate debug info location for different adjoint buffers |
| 7 | +// created for the single input variable |
| 8 | + |
| 9 | +import _Differentiation |
| 10 | + |
| 11 | +public extension Array { |
| 12 | + @inlinable |
| 13 | + @differentiable(reverse) |
| 14 | + mutating func update(at index: Int, byCalling closure: @differentiable(reverse) (inout Element) -> Void) where Element: Differentiable { |
| 15 | + closure(&self[index]) |
| 16 | + } |
| 17 | +} |
| 18 | + |
| 19 | +public func valueWithPullback<T>( |
| 20 | + at x: T, of f: @differentiable(reverse) (inout T) -> Void |
| 21 | +) -> (value: Void, pullback: (inout T.TangentVector) -> Void) { |
| 22 | + @differentiable(reverse) |
| 23 | + func nonInoutWrappingFunction(_ t: T) -> T { |
| 24 | + var t = t |
| 25 | + f(&t) |
| 26 | + return t |
| 27 | + } |
| 28 | + let nonInoutPullback = pullback(at: x, of: nonInoutWrappingFunction) |
| 29 | + return ((), { $0 = nonInoutPullback($0) }) |
| 30 | +} |
| 31 | + |
| 32 | +@inlinable |
| 33 | +public func pullback<T>( |
| 34 | + at x: T, of f: @differentiable(reverse) (inout T) -> Void |
| 35 | +) -> (inout T.TangentVector) -> Void { |
| 36 | + return valueWithPullback(at: x, of: f).pullback |
| 37 | +} |
| 38 | + |
| 39 | +// CHECK-LABEL: sil private @$s4main19testUpdateByCallingyyKF8fOfArrayL_5arraySdSaySdG_tFySdzcfU_TJpSUpSr : |
| 40 | +// CHECK: alloc_stack $Double, var, name "derivative of 'element' in scope at {{.*}} (scope #3)" |
| 41 | +// CHECK: debug_value %{{.*}} : $Builtin.FPIEEE64, var, (name "derivative of 'element' in scope at {{.*}} (scope #1)" |
| 42 | + |
| 43 | +public extension Array where Element: Differentiable { |
| 44 | + @inlinable |
| 45 | + @derivative(of: update(at:byCalling:)) |
| 46 | + mutating func vjpUpdate( |
| 47 | + at index: Int, |
| 48 | + byCalling closure: @differentiable(reverse) (inout Element) -> Void |
| 49 | + ) |
| 50 | + -> |
| 51 | + (value: Void, pullback: (inout Self.TangentVector) -> Void) |
| 52 | + { |
| 53 | + let closurePullback = pullback(at: self[index], of: closure) |
| 54 | + return (value: (), pullback: { closurePullback(&$0.base[index]) }) |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +func testUpdateByCalling() throws { |
| 59 | + @differentiable(reverse) |
| 60 | + func fOfArray(array: [Double]) -> Double { |
| 61 | + var array = array |
| 62 | + var result = 0.0 |
| 63 | + for i in withoutDerivative(at: 0 ..< array.count) { |
| 64 | + array.update(at: i, byCalling: { (element: inout Double) in |
| 65 | + let initialElement = element |
| 66 | + for _ in withoutDerivative(at: 0 ..< i) { |
| 67 | + element *= initialElement |
| 68 | + } |
| 69 | + }) |
| 70 | + result += array[i] |
| 71 | + } |
| 72 | + return result |
| 73 | + } |
| 74 | + |
| 75 | + let array = [Double](repeating: 1.0, count: 3) |
| 76 | + let expectedGradientOfFOfArray = [1.0, 2.0, 3.0] |
| 77 | + let obtainedGradientOfFOfArray = gradient(at: array, of: fOfArray).base |
| 78 | +} |
| 79 | + |
0 commit comments