@@ -128,3 +128,73 @@ func f3(v: NonPiecewiseMaterializableWithAggDifferentiableField) -> PiecewiseMat
128
128
// CHECK: destroy_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField
129
129
// CHECK: return %16 : $NonPiecewiseMaterializableWithAggDifferentiableField
130
130
// CHECK: } // end sil function '$s19pullback_generation2f31vAA23PiecewiseMaterializableVAA03NondE26WithAggDifferentiableFieldV_tFTJpSpSr'
131
+
132
+ //===----------------------------------------------------------------------===//
133
+ // Pullback generation - `struct_extract`
134
+ // - Adjoint of extracted element can be `AddElement`
135
+ // - Just need to make sure that we are able to generate a pullback for B.x's
136
+ // getter
137
+ //===----------------------------------------------------------------------===//
138
+ struct A : Differentiable {
139
+ public var x : Float
140
+ }
141
+
142
+ struct B : Differentiable {
143
+ var y : A
144
+
145
+ public init ( a: A ) {
146
+ self . y = a
147
+ }
148
+
149
+ @differentiable ( reverse)
150
+ public var x : Float {
151
+ get { return self . y. x }
152
+ }
153
+ }
154
+
155
+ // CHECK-LABEL: sil private [ossa] @$s19pullback_generation1BV1xSfvgTJpSpSr : $@convention(thin) (Float) -> B.TangentVector {
156
+
157
+ //===----------------------------------------------------------------------===//
158
+ // Pullback generation - Inner values of concrete adjoints must be copied
159
+ // during indirect materialization
160
+ //===----------------------------------------------------------------------===//
161
+
162
+ struct NonTrivial {
163
+ var x : Float
164
+ var y : String
165
+ }
166
+
167
+ extension NonTrivial : Differentiable , Equatable , AdditiveArithmetic {
168
+ public typealias TangentVector = Self
169
+ mutating func move( by offset: TangentVector ) { fatalError ( ) }
170
+ public static var zero : Self { fatalError ( ) }
171
+ public static func + ( lhs: Self , rhs: Self ) -> Self { fatalError ( ) }
172
+ public static func - ( lhs: Self , rhs: Self ) -> Self { fatalError ( ) }
173
+ }
174
+
175
+ @differentiable ( reverse)
176
+ func f4( a: NonTrivial ) -> Float {
177
+ var sum : Float = 0
178
+ for _ in 0 ..< 1 {
179
+ sum += a. x
180
+ }
181
+ return sum
182
+ }
183
+
184
+ // CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
185
+ // CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
186
+ // CHECK: %88 = alloc_stack $NonTrivial
187
+
188
+ // Non-trivial value must be copied or there will be a
189
+ // double consume when all owned parameters are destroyed
190
+ // at the end of the basic block.
191
+ // CHECK: %89 = copy_value %67 : $NonTrivial
192
+
193
+ // CHECK: store %89 to [init] %88 : $*NonTrivial
194
+ // CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
195
+ // CHECK: %92 = alloc_stack $Float
196
+ // CHECK: store %86 to [trivial] %92 : $*Float
197
+ // CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
198
+ // CHECK: %95 = metatype $@thick Float.Type
199
+ // CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200
+ // CHECK: destroy_value %67 : $NonTrivial
0 commit comments