1
1
module ParametricExpressionModule
2
2
3
3
using DispatchDoctor: @stable , @unstable
4
+ using ChainRulesCore: ChainRulesCore, NoTangent
4
5
5
- using .. OperatorEnumModule: AbstractOperatorEnum
6
+ using .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
6
7
using .. NodeModule: AbstractExpressionNode, Node, tree_mapreduce
7
8
using .. ExpressionModule: AbstractExpression, Metadata
9
+ using .. ChainRulesModule: NodeTangent
8
10
9
11
import .. NodeModule: constructorof, preserve_sharing, leaf_copy, leaf_hash, leaf_equal
10
12
import .. NodeUtilsModule:
@@ -250,7 +252,7 @@ function (ex::ParametricExpression)(
250
252
operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
251
253
kws... ,
252
254
) where {T}
253
- (output, flag) = eval_tree_array (ex, X, classes, operators; kws... ) # Will error
255
+ (output, flag) = eval_tree_array (ex, X, classes, operators; kws... )
254
256
if ! flag
255
257
output .= NaN
256
258
end
@@ -276,6 +278,71 @@ function eval_tree_array(
276
278
regular_tree = convert (Node, ex)
277
279
return eval_tree_array (regular_tree, params_and_X, get_operators (ex, operators); kws... )
278
280
end
281
+ function ChainRulesCore. rrule (
282
+ :: typeof (eval_tree_array),
283
+ ex:: ParametricExpression{T} ,
284
+ X:: AbstractMatrix{T} ,
285
+ classes:: AbstractVector{<:Integer} ,
286
+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
287
+ kws... ,
288
+ ) where {T}
289
+ primal, complete = eval_tree_array (ex, X, classes, operators; kws... )
290
+
291
+ # TODO : Preferable to use the primal in the pullback somehow
292
+ function pullback ((dY, _))
293
+ parameters = ex. metadata. parameters
294
+ num_params = size (parameters, 1 )
295
+ num_classes = size (parameters, 2 )
296
+ indexed_parameters = [
297
+ parameters[i_parameter, classes[i_row]] for
298
+ i_parameter in eachindex (axes (parameters, 1 )), i_row in eachindex (classes)
299
+ ]
300
+ params_and_X = vcat (indexed_parameters, X)
301
+ tree = ex. tree
302
+ regular_tree = convert (Node, ex)
303
+
304
+ _, gradient_tree, complete1 = eval_grad_tree_array (
305
+ regular_tree, params_and_X, operators; variable= Val (false )
306
+ )
307
+ _, gradient_params_and_X, complete2 = eval_grad_tree_array (
308
+ regular_tree, params_and_X, operators; variable= Val (true )
309
+ )
310
+
311
+ if ! complete1
312
+ gradient_tree .= NaN
313
+ end
314
+ if ! complete2
315
+ gradient_params_and_X .= NaN
316
+ end
317
+
318
+ d_tree = NodeTangent (
319
+ tree,
320
+ sum (j -> gradient_tree[:, j] * dY[j], eachindex (dY, axes (gradient_tree, 2 ))),
321
+ )
322
+ reshaped_d_Y = reshape (dY, 1 , length (dY))
323
+ d_indexed_parameters = @view (gradient_params_and_X[1 : num_params, :]) .* reshaped_d_Y
324
+ d_X = @view (gradient_params_and_X[(num_params + 1 ): end , :]) .* reshaped_d_Y
325
+ d_parameters = [
326
+ sum (
327
+ j -> d_indexed_parameters[param, j] * dY[j] * (classes[j] == class),
328
+ eachindex (classes, axes (d_indexed_parameters, 2 )),
329
+ ) for param in 1 : num_params, class in 1 : num_classes
330
+ ]
331
+ d_ex = (;
332
+ tree= d_tree,
333
+ metadata= (;
334
+ operators= NoTangent (),
335
+ variable_names= NoTangent (),
336
+ parameters= d_parameters,
337
+ parameter_names= NoTangent (),
338
+ ),
339
+ )
340
+ return (NoTangent (), d_ex, copy (d_X), NoTangent (), NoTangent ())
341
+ end
342
+
343
+ return (primal, complete), pullback
344
+ end
345
+
279
346
function string_tree (
280
347
ex:: ParametricExpression ,
281
348
operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
0 commit comments