Skip to content

Commit 2711934

Browse files
committed
feat: define chain rule for convert rather than eval_tree_array
1 parent 2219c0b commit 2711934

File tree

1 file changed

+31
-68
lines changed

1 file changed

+31
-68
lines changed

src/ParametricExpression.jl

Lines changed: 31 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ParametricExpressionModule
22

33
using DispatchDoctor: @stable, @unstable
4-
using ChainRulesCore: ChainRulesCore, NoTangent
4+
using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
55

66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
@@ -238,6 +238,36 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
238238
Node{T},
239239
)
240240
end
241+
function ChainRulesCore.rrule(
242+
::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}
243+
) where {T}
244+
tree = get_contents(ex)
245+
primal = convert(Node, ex)
246+
pullback = let tree = tree
247+
d_primal -> let
248+
# ^The exact same tangent with respect to constants, so we can just take it.
249+
d_ex = @thunk(
250+
let
251+
parametric_node_tangent = NodeTangent(tree, d_primal.gradient)
252+
(;
253+
tree=parametric_node_tangent,
254+
metadata=(;
255+
_data=(;
256+
operators=NoTangent(),
257+
variable_names=NoTangent(),
258+
parameters=NoTangent(),
259+
parameter_names=NoTangent(),
260+
)
261+
),
262+
)
263+
end
264+
)
265+
(NoTangent(), NoTangent(), d_ex)
266+
end
267+
end
268+
return primal, pullback
269+
end
270+
241271
#! format: off
242272
function (ex::ParametricExpression)(X::AbstractMatrix, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...)
243273
return eval_tree_array(ex, X, operators; kws...) # Will error
@@ -278,73 +308,6 @@ function eval_tree_array(
278308
regular_tree = convert(Node, ex)
279309
return eval_tree_array(regular_tree, params_and_X, get_operators(ex, operators); kws...)
280310
end
281-
function ChainRulesCore.rrule(
282-
::typeof(eval_tree_array),
283-
ex::ParametricExpression{T},
284-
X::AbstractMatrix{T},
285-
classes::AbstractVector{<:Integer},
286-
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
287-
kws...,
288-
) where {T}
289-
primal, complete = eval_tree_array(ex, X, classes, operators; kws...)
290-
291-
# TODO: Preferable to use the primal in the pullback somehow
292-
function pullback((dY, _))
293-
parameters = ex.metadata.parameters
294-
num_params = size(parameters, 1)
295-
num_classes = size(parameters, 2)
296-
_operators = get_operators(ex, operators)
297-
indexed_parameters = [
298-
parameters[i_parameter, classes[i_row]] for
299-
i_parameter in eachindex(axes(parameters, 1)), i_row in eachindex(classes)
300-
]
301-
params_and_X = vcat(indexed_parameters, X)
302-
tree = ex.tree
303-
regular_tree = convert(Node, ex)
304-
305-
_, gradient_tree, complete1 = eval_grad_tree_array(
306-
regular_tree, params_and_X, _operators; variable=Val(false)
307-
)
308-
_, gradient_params_and_X, complete2 = eval_grad_tree_array(
309-
regular_tree, params_and_X, _operators; variable=Val(true)
310-
)
311-
312-
if !complete1
313-
gradient_tree .= NaN
314-
end
315-
if !complete2
316-
gradient_params_and_X .= NaN
317-
end
318-
319-
d_tree = NodeTangent(
320-
tree,
321-
sum(j -> gradient_tree[:, j] * dY[j], eachindex(dY, axes(gradient_tree, 2))),
322-
)
323-
reshaped_d_Y = reshape(dY, 1, length(dY))
324-
d_indexed_parameters = @view(gradient_params_and_X[1:num_params, :]) .* reshaped_d_Y
325-
d_X = @view(gradient_params_and_X[(num_params + 1):end, :]) .* reshaped_d_Y
326-
d_parameters = [
327-
sum(
328-
j -> d_indexed_parameters[param, j] * dY[j] * (classes[j] == class),
329-
eachindex(classes, axes(d_indexed_parameters, 2)),
330-
) for param in 1:num_params, class in 1:num_classes
331-
]
332-
d_ex = (;
333-
tree=d_tree,
334-
metadata=(;
335-
_data=(;
336-
operators=NoTangent(),
337-
variable_names=NoTangent(),
338-
parameters=d_parameters,
339-
parameter_names=NoTangent(),
340-
),
341-
),
342-
)
343-
return (NoTangent(), d_ex, d_X, NoTangent(), NoTangent())
344-
end
345-
346-
return (primal, complete), pullback
347-
end
348311

349312
function string_tree(
350313
ex::ParametricExpression,

0 commit comments

Comments
 (0)