@@ -139,3 +139,50 @@ bb0(%0 : $Class):
139
139
// CHECK: [[VJP_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_FN]]([[ARG]])
140
140
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN_PARTIALLY_APPLIED]] : {{.*}} with_derivative {[[JVP_FN_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_FN_PARTIALLY_APPLIED]] : {{.*}}}
141
141
// CHECK: }
142
+
143
+ // Test curry thunks.
144
+
145
+ struct S {
146
+ var x: Float
147
+ }
148
+
149
+ sil @method : $@convention(method) (Float, S) -> Float {
150
+ bb0(%0 : $Float, %1 : $S):
151
+ return %0 : $Float
152
+ }
153
+
154
+ sil @method_thin : $@convention(thin) (Float, S) -> Float {
155
+ bb0(%0 : $Float, %1 : $S):
156
+ %4 = function_ref @method : $@convention(method) (Float, S) -> Float
157
+ %5 = apply %4(%0, %1) : $@convention(method) (Float, S) -> Float
158
+ return %5 : $Float
159
+ }
160
+
161
+ sil @method_curried : $@convention(thin) (S) -> @owned @callee_guaranteed (Float) -> Float {
162
+ bb0(%0 : $S):
163
+ %2 = function_ref @method_thin : $@convention(thin) (Float, S) -> Float
164
+ %3 = partial_apply [callee_guaranteed] %2(%0) : $@convention(thin) (Float, S) -> Float
165
+ return %3 : $@callee_guaranteed (Float) -> Float
166
+ }
167
+
168
+ sil @test_curry_thunk : $@convention(thin) (Float, S) -> () {
169
+ bb0(%0 : $Float, %1 : $S):
170
+ %2 = function_ref @method_curried : $@convention(thin) (S) -> @owned @callee_guaranteed (Float) -> Float
171
+ %3 = apply %2(%1) : $@convention(thin) (S) -> @owned @callee_guaranteed (Float) -> Float
172
+ %4 = differentiable_function [parameters 0] [results 0] %3 : $@callee_guaranteed (Float) -> Float
173
+ %5 = tuple ()
174
+ return %5 : $()
175
+ }
176
+
177
+ // CHECK-LABEL: sil {{.*}} @AD__method_curried__differentiable_curry_thunk_src_0_wrt_0
178
+ // CHECK: bb0([[SELF:%.*]] : $S):
179
+ // CHECK: [[METHOD:%.*]] = function_ref @method_thin : $@convention(thin) (Float, S) -> Float
180
+ // CHECK: [[CURRIED:%.*]] = partial_apply [callee_guaranteed] [[METHOD]]([[SELF]]) : $@convention(thin) (Float, S) -> Float
181
+ // CHECK: [[METHOD_JVP:%.*]] = differentiability_witness_function [jvp] [parameters 0] [results 0] @method_thin : $@convention(thin) (Float, S) -> Float
182
+ // CHECK: [[CURRIED_JVP:%.*]] = partial_apply [callee_guaranteed] [[METHOD_JVP]]([[SELF]]) : $@convention(thin) (Float, S) -> (Float, @owned @callee_guaranteed (Float) -> Float)
183
+ // CHECK: [[METHOD_VJP:%.*]] = differentiability_witness_function [vjp] [parameters 0] [results 0] @method_thin : $@convention(thin) (Float, S) -> Float
184
+ // CHECK: [[CURRIED_VJP:%.*]] = partial_apply [callee_guaranteed] [[METHOD_VJP]]([[SELF]]) : $@convention(thin) (Float, S) -> (Float, @owned @callee_guaranteed (Float) -> Float)
185
+ // CHECK: strong_retain [[CURRIED]] : $@callee_guaranteed (Float) -> Float
186
+ // CHECK: [[RESULT:%.*]] = differentiable_function [parameters 0] [results 0] [[CURRIED]] : $@callee_guaranteed (Float) -> Float with_derivative {[[CURRIED_JVP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[CURRIED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
187
+ // CHECK: strong_release [[CURRIED]] : $@callee_guaranteed (Float) -> Float
188
+ // CHECK: return [[RESULT]] : $@differentiable @callee_guaranteed (Float) -> Float
0 commit comments