|
| 1 | +// Pullback generation tests written in SIL for features |
| 2 | +// that may not be directly supported by the Swift frontend |
| 3 | + |
| 4 | +// RUN: %target-sil-opt --differentiation -debug-only=differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s |
| 5 | + |
| 6 | +//===----------------------------------------------------------------------===// |
| 7 | +// Pullback generation - `struct_extract` |
| 8 | +// - Input to pullback has non-owned ownership semantics which requires copying |
| 9 | +// this value to stack before lifetime-ending uses. |
| 10 | +//===----------------------------------------------------------------------===// |
| 11 | + |
| 12 | +sil_stage raw |
| 13 | + |
| 14 | +import Builtin |
| 15 | +import Swift |
| 16 | +import SwiftShims |
| 17 | + |
| 18 | +import _Differentiation |
| 19 | + |
| 20 | +struct X { |
| 21 | + @_hasStorage var a: Float { get set } |
| 22 | + @_hasStorage var b: String { get set } |
| 23 | + init(a: Float, b: String) |
| 24 | +} |
| 25 | + |
| 26 | +extension X : Differentiable, Equatable, AdditiveArithmetic { |
| 27 | + public typealias TangentVector = X |
| 28 | + mutating func move(by offset: X) |
| 29 | + public static var zero: X { get } |
| 30 | + public static func + (lhs: X, rhs: X) -> X |
| 31 | + public static func - (lhs: X, rhs: X) -> X |
| 32 | + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool |
| 33 | +} |
| 34 | + |
| 35 | +struct Y { |
| 36 | + @_hasStorage var a: X { get set } |
| 37 | + @_hasStorage var b: String { get set } |
| 38 | + init(a: X, b: String) |
| 39 | +} |
| 40 | + |
| 41 | +extension Y : Differentiable, Equatable, AdditiveArithmetic { |
| 42 | + public typealias TangentVector = Y |
| 43 | + mutating func move(by offset: Y) |
| 44 | + public static var zero: Y { get } |
| 45 | + public static func + (lhs: Y, rhs: Y) -> Y |
| 46 | + public static func - (lhs: Y, rhs: Y) -> Y |
| 47 | + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool |
| 48 | +} |
| 49 | + |
| 50 | +sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X { |
| 51 | +} |
| 52 | + |
| 53 | +sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X { |
| 54 | +bb0(%0 : @guaranteed $Y): |
| 55 | + %1 = struct_extract %0 : $Y, #Y.a |
| 56 | + %2 = copy_value %1 : $X |
| 57 | + return %2 : $X |
| 58 | +} |
| 59 | + |
| 60 | +// CHECK-LABEL: [ORIG] %1 = struct_extract %0 : $Y, #Y.a // user: %2 |
| 61 | +// CHECK: [ADJ] Emitted in pullback (pb bb0): |
| 62 | +// CHECK: %1 = alloc_stack $Y // users: {{.*}} |
| 63 | +// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %4 |
| 64 | +// CHECK: %3 = metatype $@thick Y.Type // user: %4 |
| 65 | +// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 66 | + |
| 67 | +// Since input parameter $0 has non-owned ownership semantics, it |
| 68 | +// needs to be copied before a lifetime-ending use. |
| 69 | +// CHECK: %5 = copy_value %0 : $X // user: %7 |
| 70 | + |
| 71 | +// CHECK: %6 = alloc_stack $X // users: {{.*}} |
| 72 | +// CHECK: store %5 to [init] %6 : $*X // id: %7 |
| 73 | +// CHECK: %8 = struct_element_addr %1 : $*Y, #Y.a // user: %11 |
| 74 | +// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %11 |
| 75 | +// CHECK: %10 = metatype $@thick X.Type // user: %11 |
| 76 | +// CHECK: %11 = apply %9<X>(%8, %6, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () |
| 77 | +// CHECK: %12 = load [take] %1 : $*Y |
| 78 | +// CHECK: destroy_addr %6 : $*X // id: %13 |
| 79 | +// CHECK: dealloc_stack %6 : $*X // id: %14 |
| 80 | +// CHECK: dealloc_stack %1 : $*Y // id: %15 |
| 81 | + |
| 82 | +//===----------------------------------------------------------------------===// |
| 83 | +// Pullback generation - `tuple_extract` |
| 84 | +// - Tuples as differentiable input arguments are not supported yet, so creating |
| 85 | +// a basic test in SIL instead. |
| 86 | +//===----------------------------------------------------------------------===// |
| 87 | + |
| 88 | +sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float { |
| 89 | +} |
| 90 | + |
| 91 | +sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float { |
| 92 | +bb0(%0 : $(Float, Float)): |
| 93 | + %1 = tuple_extract %0 : $(Float, Float), 0 |
| 94 | + return %1 : $Float |
| 95 | +} |
| 96 | + |
| 97 | +// CHECK-LABEL: [ORIG] %1 = tuple_extract %0 : $(Float, Float), 0 // user: %2 |
| 98 | +// CHECK: [ADJ] Emitted in pullback (pb bb0): |
| 99 | +// CHECK: %1 = alloc_stack $(Float, Float) // users: {{.*}} |
| 100 | +// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %5 |
| 101 | +// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5 |
| 102 | +// CHECK: %4 = metatype $@thick Float.Type // user: %5 |
| 103 | +// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 104 | +// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1 // user: %9 |
| 105 | +// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9 |
| 106 | +// CHECK: %8 = metatype $@thick Float.Type // user: %9 |
| 107 | +// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 108 | +// CHECK: %10 = alloc_stack $Float // users: {{.*}} |
| 109 | +// CHECK: store %0 to [trivial] %10 : $*Float // id: %11 |
| 110 | +// CHECK: %12 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %15 |
| 111 | +// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %15 |
| 112 | +// CHECK: %14 = metatype $@thick Float.Type // user: %15 |
| 113 | +// CHECK: %15 = apply %13<Float>(%12, %10, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () |
| 114 | +// CHECK: %16 = load [trivial] %1 : $*(Float, Float) |
| 115 | +// CHECK: destroy_addr %10 : $*Float // id: %17 |
| 116 | +// CHECK: dealloc_stack %10 : $*Float // id: %18 |
| 117 | +// CHECK: dealloc_stack %1 : $*(Float, Float) // id: %19 |
| 118 | + |
| 119 | +//===----------------------------------------------------------------------===// |
| 120 | +// Pullback generation - `tuple_extract` |
| 121 | +// - Input to pullback has non-owned ownership semantics which requires copying |
| 122 | +// this value to stack before lifetime-ending uses. |
| 123 | +//===----------------------------------------------------------------------===// |
| 124 | +sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X { |
| 125 | +} |
| 126 | + |
| 127 | +sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X { |
| 128 | +bb0(%0 : @guaranteed $(X, X)): |
| 129 | + %1 = tuple_extract %0 : $(X, X), 0 |
| 130 | + %2 = copy_value %1: $X |
| 131 | + return %2 : $X |
| 132 | +} |
| 133 | + |
| 134 | +// CHECK-LABEL: [ORIG] %1 = tuple_extract %0 : $(X, X), 0 // user: %2 |
| 135 | +// CHECK: [ADJ] Emitted in pullback (pb bb0): |
| 136 | +// CHECK: %1 = alloc_stack $(X, X) // users: {{.*}} |
| 137 | +// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0 // user: %5 |
| 138 | +// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5 |
| 139 | +// CHECK: %4 = metatype $@thick X.Type // user: %5 |
| 140 | +// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 141 | +// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1 // user: %9 |
| 142 | +// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9 |
| 143 | +// CHECK: %8 = metatype $@thick X.Type // user: %9 |
| 144 | +// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 145 | +// CHECK: %10 = copy_value %0 : $X // user: %12 |
| 146 | +// CHECK: %11 = alloc_stack $X // users: {{.*}} |
| 147 | +// CHECK: store %10 to [init] %11 : $*X // id: %12 |
| 148 | +// CHECK: %13 = tuple_element_addr %1 : $*(X, X), 0 // user: %16 |
| 149 | +// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %16 |
| 150 | +// CHECK: %15 = metatype $@thick X.Type // user: %16 |
| 151 | +// CHECK: %16 = apply %14<X>(%13, %11, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () |
| 152 | +// CHECK: %17 = load [take] %1 : $*(X, X) |
| 153 | +// CHECK: destroy_addr %11 : $*X // id: %18 |
| 154 | +// CHECK: dealloc_stack %11 : $*X // id: %19 |
| 155 | +// CHECK: dealloc_stack %1 : $*(X, X) // id: %20 |
0 commit comments