|
1 | 1 | module ParametricExpressionModule
|
2 | 2 |
|
3 | 3 | using DispatchDoctor: @stable, @unstable
|
4 |
| -using ChainRulesCore: ChainRulesCore, NoTangent |
| 4 | +using ChainRulesCore: ChainRulesCore, NoTangent, @thunk |
5 | 5 |
|
6 | 6 | using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
|
7 | 7 | using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
|
@@ -238,6 +238,36 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
|
238 | 238 | Node{T},
|
239 | 239 | )
|
240 | 240 | end
|
| 241 | +function ChainRulesCore.rrule( |
| 242 | + ::typeof(convert), ::Type{Node}, ex::ParametricExpression{T} |
| 243 | +) where {T} |
| 244 | + tree = get_contents(ex) |
| 245 | + primal = convert(Node, ex) |
| 246 | + pullback = let tree = tree |
| 247 | + d_primal -> let |
| 248 | + # ^The exact same tangent with respect to constants, so we can just take it. |
| 249 | + d_ex = @thunk( |
| 250 | + let |
| 251 | + parametric_node_tangent = NodeTangent(tree, d_primal.gradient) |
| 252 | + (; |
| 253 | + tree=parametric_node_tangent, |
| 254 | + metadata=(; |
| 255 | + _data=(; |
| 256 | + operators=NoTangent(), |
| 257 | + variable_names=NoTangent(), |
| 258 | + parameters=NoTangent(), |
| 259 | + parameter_names=NoTangent(), |
| 260 | + ) |
| 261 | + ), |
| 262 | + ) |
| 263 | + end |
| 264 | + ) |
| 265 | + (NoTangent(), NoTangent(), d_ex) |
| 266 | + end |
| 267 | + end |
| 268 | + return primal, pullback |
| 269 | +end |
| 270 | + |
241 | 271 | #! format: off
|
242 | 272 | function (ex::ParametricExpression)(X::AbstractMatrix, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...)
|
243 | 273 | return eval_tree_array(ex, X, operators; kws...) # Will error
|
@@ -278,73 +308,6 @@ function eval_tree_array(
|
278 | 308 | regular_tree = convert(Node, ex)
|
279 | 309 | return eval_tree_array(regular_tree, params_and_X, get_operators(ex, operators); kws...)
|
280 | 310 | end
|
281 |
| -function ChainRulesCore.rrule( |
282 |
| - ::typeof(eval_tree_array), |
283 |
| - ex::ParametricExpression{T}, |
284 |
| - X::AbstractMatrix{T}, |
285 |
| - classes::AbstractVector{<:Integer}, |
286 |
| - operators::Union{AbstractOperatorEnum,Nothing}=nothing; |
287 |
| - kws..., |
288 |
| -) where {T} |
289 |
| - primal, complete = eval_tree_array(ex, X, classes, operators; kws...) |
290 |
| - |
291 |
| - # TODO: Preferable to use the primal in the pullback somehow |
292 |
| - function pullback((dY, _)) |
293 |
| - parameters = ex.metadata.parameters |
294 |
| - num_params = size(parameters, 1) |
295 |
| - num_classes = size(parameters, 2) |
296 |
| - _operators = get_operators(ex, operators) |
297 |
| - indexed_parameters = [ |
298 |
| - parameters[i_parameter, classes[i_row]] for |
299 |
| - i_parameter in eachindex(axes(parameters, 1)), i_row in eachindex(classes) |
300 |
| - ] |
301 |
| - params_and_X = vcat(indexed_parameters, X) |
302 |
| - tree = ex.tree |
303 |
| - regular_tree = convert(Node, ex) |
304 |
| - |
305 |
| - _, gradient_tree, complete1 = eval_grad_tree_array( |
306 |
| - regular_tree, params_and_X, _operators; variable=Val(false) |
307 |
| - ) |
308 |
| - _, gradient_params_and_X, complete2 = eval_grad_tree_array( |
309 |
| - regular_tree, params_and_X, _operators; variable=Val(true) |
310 |
| - ) |
311 |
| - |
312 |
| - if !complete1 |
313 |
| - gradient_tree .= NaN |
314 |
| - end |
315 |
| - if !complete2 |
316 |
| - gradient_params_and_X .= NaN |
317 |
| - end |
318 |
| - |
319 |
| - d_tree = NodeTangent( |
320 |
| - tree, |
321 |
| - sum(j -> gradient_tree[:, j] * dY[j], eachindex(dY, axes(gradient_tree, 2))), |
322 |
| - ) |
323 |
| - reshaped_d_Y = reshape(dY, 1, length(dY)) |
324 |
| - d_indexed_parameters = @view(gradient_params_and_X[1:num_params, :]) .* reshaped_d_Y |
325 |
| - d_X = @view(gradient_params_and_X[(num_params + 1):end, :]) .* reshaped_d_Y |
326 |
| - d_parameters = [ |
327 |
| - sum( |
328 |
| - j -> d_indexed_parameters[param, j] * dY[j] * (classes[j] == class), |
329 |
| - eachindex(classes, axes(d_indexed_parameters, 2)), |
330 |
| - ) for param in 1:num_params, class in 1:num_classes |
331 |
| - ] |
332 |
| - d_ex = (; |
333 |
| - tree=d_tree, |
334 |
| - metadata=(; |
335 |
| - _data=(; |
336 |
| - operators=NoTangent(), |
337 |
| - variable_names=NoTangent(), |
338 |
| - parameters=d_parameters, |
339 |
| - parameter_names=NoTangent(), |
340 |
| - ), |
341 |
| - ), |
342 |
| - ) |
343 |
| - return (NoTangent(), d_ex, d_X, NoTangent(), NoTangent()) |
344 |
| - end |
345 |
| - |
346 |
| - return (primal, complete), pullback |
347 |
| -end |
348 | 311 |
|
349 | 312 | function string_tree(
|
350 | 313 | ex::ParametricExpression,
|
|
0 commit comments