
Commit ea8e020

Merge pull request #89 from SymbolicML/generalize-node-tangents
Create `extract_gradient`
2 parents f1ceb88 + ed5deaa commit ea8e020

8 files changed: +86 / -16 lines changed

src/ChainRules.jl

Lines changed: 11 additions & 2 deletions
@@ -1,7 +1,13 @@
 module ChainRulesModule
 
 using ChainRulesCore:
-    ChainRulesCore, AbstractTangent, NoTangent, ZeroTangent, Tangent, @thunk, canonicalize
+    ChainRulesCore as CRC,
+    AbstractTangent,
+    NoTangent,
+    ZeroTangent,
+    Tangent,
+    @thunk,
+    canonicalize
 using ..OperatorEnumModule: OperatorEnum
 using ..NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
 using ..EvaluateModule: eval_tree_array
@@ -11,6 +17,9 @@ struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: Abstra
     tree::N
     gradient::A
 end
+function extract_gradient(gradient::NodeTangent, ::AbstractExpressionNode)
+    return gradient.gradient
+end
 function Base.:+(a::NodeTangent, b::NodeTangent)
     # @assert a.tree == b.tree
     return NodeTangent(a.tree, a.gradient + b.gradient)
@@ -19,7 +28,7 @@ Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
 Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
 Base.zero(::Union{Type{NodeTangent},NodeTangent}) = ZeroTangent()
 
-function ChainRulesCore.rrule(
+function CRC.rrule(
     ::typeof(eval_tree_array),
     tree::AbstractExpressionNode,
     X::AbstractMatrix,
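
The `NodeTangent` above pairs a tree with a flat array of constant derivatives, and the new `extract_gradient` method simply unwraps that array. A minimal sketch of the contract (the tangent values below are constructed by hand for illustration, not produced by any AD call):

using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, *])
x1 = Node{Float64}(; feature=1)
tree = x1 * 1.5 + 2.0   # two constants: 1.5 and 2.0

# A tangent carries the tree plus one gradient entry per constant,
# in the same order as `first(get_constants(tree))`:
tangent = NodeTangent(tree, [10.0, 20.0])
extract_gradient(tangent, tree)  # returns [10.0, 20.0]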

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ import .NodeModule:
6161
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
6262
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
6363
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
64-
@reexport import .ChainRulesModule: NodeTangent
64+
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
6565
@reexport import .SimplifyModule: combine_operators, simplify_tree!
6666
@reexport import .EvaluationHelpersModule
6767
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

src/Expression.jl

Lines changed: 14 additions & 0 deletions
@@ -5,6 +5,7 @@ using DispatchDoctor: @unstable
 using ..NodeModule: AbstractExpressionNode, Node
 using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
 using ..UtilsModule: Undefined
+using ..ChainRulesModule: NodeTangent
 
 import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof
 import ..NodeUtilsModule:
@@ -16,6 +17,7 @@ import ..NodeUtilsModule:
     has_constants,
     get_constants,
     set_constants!
+import ..ChainRulesModule: extract_gradient
 
 """A wrapper for a named tuple to avoid piracy."""
 struct Metadata{NT<:NamedTuple}
@@ -140,6 +142,12 @@ end
 function set_constants!(ex::AbstractExpression{T}, constants, refs) where {T}
     return error("`set_constants!` function must be implemented for $(typeof(ex)) types.")
 end
+function extract_gradient(gradient, ex::AbstractExpression)
+    # Should match `get_constants`
+    return error(
+        "`extract_gradient` function must be implemented for $(typeof(ex)) types with $(typeof(gradient)) gradient.",
+    )
+end
 function get_contents(ex::AbstractExpression)
     return error("`get_contents` function must be implemented for $(typeof(ex)) types.")
 end
@@ -263,6 +271,12 @@ end
 function set_constants!(ex::Expression{T}, constants, refs) where {T}
     return set_constants!(get_tree(ex), constants, refs)
 end
+function extract_gradient(
+    gradient::@NamedTuple{tree::NT, metadata::Nothing}, ex::Expression{T,N}
+) where {T,N<:AbstractExpressionNode{T},NT<:NodeTangent{T,N}}
+    # TODO: This messy gradient type is produced by ChainRules. There is probably a better way to do this.
+    return extract_gradient(gradient.tree, get_tree(ex))
+end
 
 import ..StringsModule: string_tree, print_tree
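
For an `Expression`, the structural gradient that Zygote produces is a named tuple with a `tree` field holding a `NodeTangent` and a `metadata` field of `nothing`, which is what the new method dispatches on. A hedged sketch of the flow, assuming plain `Zygote.gradient` behaves like the `DifferentiationInterface` route used in the new test in `test/test_expressions.jl` (exact tangent shapes can vary between AD setups):

using DynamicExpressions
using Zygote: Zygote

ex = @parse_expression(x1 * 1.5 + 2.0, binary_operators = [+, *], variable_names = ["x1"])
X = ones(1, 5)

(d_ex,) = Zygote.gradient(e -> sum(e(X)), ex)
# `d_ex` is roughly `(tree = NodeTangent(...), metadata = nothing)`;
# `extract_gradient` flattens it to match `first(get_constants(ex))`:
extract_gradient(d_ex, ex)  # ≈ [5.0, 5.0] (one entry per constant: 1.5 and 2.0)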

src/Interfaces.jl

Lines changed: 17 additions & 7 deletions
@@ -153,12 +153,16 @@ ei_components = (
         index_constants = "indexes constants in the expression tree" => _check_index_constants,
         has_operators = "checks if the expression has operators" => _check_has_operators,
         has_constants = "checks if the expression has constants" => _check_has_constants,
-        get_constants = "gets constants from the expression tree" => _check_get_constants,
-        set_constants! = "sets constants in the expression tree" => _check_set_constants!,
+        get_constants = ("gets constants from the expression tree, returning a tuple of: " *
+            "(1) a flat vector of the constants, and (2) a reference object that " *
+            "can be used by `set_constants!` to efficiently set them back") => _check_get_constants,
+        set_constants! = ("sets constants in the expression tree, given: " *
+            "(1) a flat vector of constants, (2) the expression, and " *
+            "(3) the reference object produced by `get_constants`") => _check_set_constants!,
         string_tree = "returns a string representation of the expression tree" => _check_string_tree,
         default_node_type = "returns the default node type for the expression" => _check_default_node,
         constructorof = "gets the constructor function for a type" => _check_constructorof,
-        tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
+        tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce,
     )
 )
 ei_description = (
@@ -332,10 +336,14 @@ ni_components = (
         count_constants = "counts the number of constants" => _check_count_constants,
         filter_map = "applies a filter and map function to the tree" => _check_filter_map,
         has_constants = "checks if the tree has constants" => _check_has_constants,
-        get_constants = "gets constants from the tree" => _check_get_constants,
-        set_constants! = "sets constants in the tree" => _check_set_constants!,
+        get_constants = ("gets constants from the tree, returning a tuple of: " *
+            "(1) a flat vector of the constants, and (2) a reference object that " *
+            "can be used by `set_constants!` to efficiently set them back") => _check_get_constants,
+        set_constants! = ("sets constants in the tree, given: " *
+            "(1) a flat vector of constants, (2) the tree, and " *
+            "(3) the reference object produced by `get_constants`") => _check_set_constants!,
         index_constants = "indexes constants in the tree" => _check_index_constants,
-        has_operators = "checks if the tree has operators" => _check_has_operators
+        has_operators = "checks if the tree has operators" => _check_has_operators,
     )
 )
 
@@ -372,6 +380,8 @@ ni_description = (
 
 #! format: on
 
-# TODO: Create an interface for evaluation
+# TODO: Create an interface for evaluation and `extract_gradient`
+# extract_gradient = ("given a Zygote-computed gradient with respect to the tree constants, " *
+#     "extracts a flat vector in the same order as `get_constants`") => _check_extract_gradient,
 
 end
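
The expanded descriptions spell out the round-trip contract between `get_constants` and `set_constants!`. A small illustrative sketch on a plain node tree (the constant ordering is whatever the tree traversal produces, so the specific values here are only an example):

using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, *])
x1 = Node{Float64}(; feature=1)
tree = x1 * 1.5 + 2.0

constants, refs = get_constants(tree)       # flat vector of constants plus a reference object
set_constants!(tree, constants .* 2, refs)  # `refs` lets the constants be written back efficiently
first(get_constants(tree))                  # now twice the original values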

src/ParametricExpression.jl

Lines changed: 23 additions & 6 deletions
@@ -1,7 +1,7 @@
 module ParametricExpressionModule
 
 using DispatchDoctor: @stable, @unstable
-using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
+using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
 
 using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
 using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
@@ -20,6 +20,7 @@ import ..StringsModule: string_tree
 import ..EvaluateModule: eval_tree_array
 import ..EvaluateDerivativeModule: eval_grad_tree_array
 import ..EvaluationHelpersModule: _grad_evaluator
+import ..ChainRulesModule: extract_gradient
 import ..ExpressionModule:
     get_contents,
     get_metadata,
@@ -207,7 +208,7 @@ has_constants(ex::ParametricExpression) = _interface_error()
 has_operators(ex::ParametricExpression) = has_operators(get_tree(ex))
 function get_constants(ex::ParametricExpression{T}) where {T}
     constants, constant_refs = get_constants(get_tree(ex))
-    parameters = ex.metadata.parameters
+    parameters = get_metadata(ex).parameters
     flat_parameters = parameters[:]
     num_constants = length(constants)
     num_parameters = length(flat_parameters)
@@ -218,9 +219,27 @@ function set_constants!(ex::ParametricExpression{T}, x, refs) where {T}
     # First, set the usual constants
     set_constants!(get_tree(ex), @view(x[1:(refs.num_constants)]), refs.constant_refs)
     # Then, copy in the parameters
-    ex.metadata.parameters[:] .= @view(x[(refs.num_constants + 1):end])
+    get_metadata(ex).parameters[:] .= @view(x[(refs.num_constants + 1):end])
     return ex
 end
+function extract_gradient(
+    gradient::@NamedTuple{
+        tree::NT,
+        metadata::@NamedTuple{
+            _data::@NamedTuple{
+                operators::Nothing,
+                variable_names::Nothing,
+                parameters::PARAM,
+                parameter_names::Nothing,
+            }
+        }
+    },
+    ex::ParametricExpression{T,N},
+) where {T,N<:ParametricNode{T},NT<:NodeTangent{T,N},PARAM<:AbstractMatrix{T}}
+    d_constants = extract_gradient(gradient.tree, get_tree(ex))
+    d_params = gradient.metadata._data.parameters[:]
+    return vcat(d_constants, d_params) # Same shape as `get_constants`
+end
 
 function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
     num_params = UInt16(size(ex.metadata.parameters, 1))
@@ -238,9 +257,7 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
         Node{T},
     )
 end
-function ChainRulesCore.rrule(
-    ::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}
-) where {T}
+function CRC.rrule(::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}) where {T}
     tree = get_contents(ex)
     primal = convert(Node, ex)
     pullback = let tree = tree
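
The `extract_gradient` method added here mirrors the layout that `get_constants` uses for a `ParametricExpression`: tree constants first, then the parameter matrix flattened in column-major order. A sketch of that convention, assuming some existing `ex::ParametricExpression` and a matching Zygote gradient `grad` (both hypothetical here):

flat, refs = get_constants(ex)                 # [constants; vec(parameters)]
d = extract_gradient(grad, ex)                 # same length and ordering as `flat`

d_constants = d[1:(refs.num_constants)]        # gradients of the tree constants
d_params = d[(refs.num_constants + 1):end]     # gradients of `parameters[:]` (column-major)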

test/test_expressions.jl

Lines changed: 13 additions & 0 deletions
@@ -76,6 +76,19 @@ end
     end
 end
 
+@testitem "Can also get derivatives of expression itself" begin
+    using DynamicExpressions
+    using Zygote: Zygote
+    using DifferentiationInterface: AutoZygote, gradient
+
+    ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"])
+    d_ex = gradient(AutoZygote(), ex) do ex
+        sum(ex(ones(1, 5)))
+    end
+    @test d_ex isa NamedTuple
+    @test extract_gradient(d_ex, ex) ≈ [5.0]
+end
+
 @testitem "Expression simplification" begin
     using DynamicExpressions

test/test_multi_expression.jl

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@
     @test_throws "`set_constants!` function must be implemented for" set_constants!(
         multi_ex, nothing, nothing
     )
+    @test_throws "`extract_gradient` function must be implemented for" extract_gradient(
+        nothing, multi_ex
+    )
 end
 
 tree_factory(f::F, trees) where {F} = f(; trees...)

test/test_parametric_expression.jl

Lines changed: 4 additions & 0 deletions
@@ -316,4 +316,8 @@ end
     @test grad.tree.gradient ≈ true_grad[3]
     # Gradient w.r.t. the parameters:
     @test grad.metadata._data.parameters ≈ true_grad[2]
+
+    # Gradient extractor
+    @test extract_gradient(grad, ex) ≈ vcat(true_grad[3], true_grad[2][:])
+    @test axes(extract_gradient(grad, ex)) == axes(first(get_constants(ex)))
 end
