Skip to content

Commit bcfc360

Browse files
committed
feat: add chain rules for parametric expression
1 parent 9e95f05 commit bcfc360

File tree

2 files changed

+70
-3
lines changed

2 files changed

+70
-3
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: Abstra
1212
gradient::A
1313
end
1414
# Add two tangents by summing their gradient arrays. Both tangents are required
# (via the assertion) to refer to equal trees; the result keeps `a`'s tree.
function Base.:+(a::NodeTangent, b::NodeTangent)
    @assert a.tree == b.tree # TODO: Remove this check
    summed_gradient = a.gradient + b.gradient
    return NodeTangent(a.tree, summed_gradient)
end
1818
# Scale a tangent by a scalar: the tree is carried over unchanged and only the
# gradient array is multiplied by `a`.
function Base.:*(a::Number, b::NodeTangent)
    scaled_gradient = a * b.gradient
    return NodeTangent(b.tree, scaled_gradient)
end

src/ParametricExpression.jl

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module ParametricExpressionModule
22

33
using DispatchDoctor: @stable, @unstable
4+
using ChainRulesCore: ChainRulesCore, NoTangent
45

5-
using ..OperatorEnumModule: AbstractOperatorEnum
6+
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
67
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
78
using ..ExpressionModule: AbstractExpression, Metadata
9+
using ..ChainRulesModule: NodeTangent
810

911
import ..NodeModule: constructorof, preserve_sharing, leaf_copy, leaf_hash, leaf_equal
1012
import ..NodeUtilsModule:
@@ -250,7 +252,7 @@ function (ex::ParametricExpression)(
250252
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
251253
kws...,
252254
) where {T}
253-
(output, flag) = eval_tree_array(ex, X, classes, operators; kws...) # Will error
255+
(output, flag) = eval_tree_array(ex, X, classes, operators; kws...)
254256
if !flag
255257
output .= NaN
256258
end
@@ -276,6 +278,71 @@ function eval_tree_array(
276278
regular_tree = convert(Node, ex)
277279
return eval_tree_array(regular_tree, params_and_X, get_operators(ex, operators); kws...)
278280
end
281+
"""
    ChainRulesCore.rrule(
        ::typeof(eval_tree_array), ex::ParametricExpression{T}, X, classes,
        [operators]; kws...
    )

Reverse-mode differentiation rule for evaluating a `ParametricExpression`.

Return `((output, complete), pullback)`, where `(output, complete)` is exactly what
`eval_tree_array(ex, X, classes, operators; kws...)` returns.

`pullback((dY, _))` takes the cotangent `dY` of `output` (the cotangent of the
`complete` flag is ignored) and returns
`(NoTangent(), d_ex, d_X, NoTangent(), NoTangent())`, where:

- `d_ex.tree` is a `NodeTangent` for `ex.tree` whose gradient is the result of
  `eval_grad_tree_array(...; variable=Val(false))` contracted with `dY`.
- `d_ex.metadata.parameters` has the same shape as `ex.metadata.parameters`
  (`num_params × num_classes`); entry `[p, c]` accumulates the contributions of
  every row `j` with `classes[j] == c`.
- `d_X` has the same shape as `X`.

If a gradient evaluation reports that it did not complete, the corresponding
gradient array is filled with `NaN` before being contracted with `dY`.
"""
function ChainRulesCore.rrule(
    ::typeof(eval_tree_array),
    ex::ParametricExpression{T},
    X::AbstractMatrix{T},
    classes::AbstractVector{<:Integer},
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
) where {T}
    primal, complete = eval_tree_array(ex, X, classes, operators; kws...)

    # TODO: Preferable to use the primal in the pullback somehow
    function pullback((dY, _))
        parameters = ex.metadata.parameters
        num_params = size(parameters, 1)
        num_classes = size(parameters, 2)

        # Expand the per-class parameters into one column per data row, then
        # stack them on top of `X` so they behave like extra input features.
        indexed_parameters = [
            parameters[i_parameter, classes[i_row]] for
            i_parameter in eachindex(axes(parameters, 1)), i_row in eachindex(classes)
        ]
        params_and_X = vcat(indexed_parameters, X)
        tree = ex.tree
        regular_tree = convert(Node, ex)

        # Resolve `operators === nothing` the same way the primal evaluation
        # does, so the gradient pass never receives `nothing`.
        resolved_operators = get_operators(ex, operators)

        _, gradient_tree, complete1 = eval_grad_tree_array(
            regular_tree, params_and_X, resolved_operators; variable=Val(false)
        )
        _, gradient_params_and_X, complete2 = eval_grad_tree_array(
            regular_tree, params_and_X, resolved_operators; variable=Val(true)
        )

        if !complete1
            gradient_tree .= NaN
        end
        if !complete2
            gradient_params_and_X .= NaN
        end

        # Contract the per-row gradients with the output cotangent; this is the
        # matrix-vector product `sum_j gradient_tree[:, j] * dY[j]`.
        d_tree = NodeTangent(tree, gradient_tree * dY)

        # Broadcasting against a 1×n row scales every column `j` by `dY[j]`.
        reshaped_d_Y = reshape(dY, 1, length(dY))
        d_indexed_parameters =
            @view(gradient_params_and_X[1:num_params, :]) .* reshaped_d_Y
        d_X = @view(gradient_params_and_X[(num_params + 1):end, :]) .* reshaped_d_Y

        # Scatter-add each row's parameter cotangent into the column of its
        # class. `d_indexed_parameters` is already weighted by `dY`, so it must
        # NOT be multiplied by `dY[j]` a second time here.
        d_parameters = zeros(eltype(d_indexed_parameters), num_params, num_classes)
        for j in eachindex(classes, axes(d_indexed_parameters, 2))
            class = classes[j]
            for param in 1:num_params
                d_parameters[param, class] += d_indexed_parameters[param, j]
            end
        end

        d_ex = (;
            tree=d_tree,
            metadata=(;
                operators=NoTangent(),
                variable_names=NoTangent(),
                parameters=d_parameters,
                parameter_names=NoTangent(),
            ),
        )
        # `d_X` is a freshly materialized broadcast result, so no copy is needed.
        return (NoTangent(), d_ex, d_X, NoTangent(), NoTangent())
    end

    return (primal, complete), pullback
end
345+
279346
function string_tree(
280347
ex::ParametricExpression,
281348
operators::Union{AbstractOperatorEnum,Nothing}=nothing;

0 commit comments

Comments
 (0)