Skip to content

Commit ba0f841

Browse files
authored
Merge pull request #88 from SymbolicML/fix-parametric-gradients
Fix gradients of parametric expressions
2 parents 9e95f05 + a2f5673 commit ba0f841

File tree

4 files changed

+88
-3
lines changed

4 files changed

+88
-3
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: Abstra
1212
gradient::A
1313
end
1414
function Base.:+(a::NodeTangent, b::NodeTangent)
15-
@assert a.tree == b.tree
15+
# @assert a.tree == b.tree
1616
return NodeTangent(a.tree, a.gradient + b.gradient)
1717
end
1818
Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)

src/ParametricExpression.jl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module ParametricExpressionModule
22

33
using DispatchDoctor: @stable, @unstable
4+
using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
45

5-
using ..OperatorEnumModule: AbstractOperatorEnum
6+
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
67
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
78
using ..ExpressionModule: AbstractExpression, Metadata
9+
using ..ChainRulesModule: NodeTangent
810

911
import ..NodeModule: constructorof, preserve_sharing, leaf_copy, leaf_hash, leaf_equal
1012
import ..NodeUtilsModule:
@@ -236,6 +238,36 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
236238
Node{T},
237239
)
238240
end
241+
function ChainRulesCore.rrule(
242+
::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}
243+
) where {T}
244+
tree = get_contents(ex)
245+
primal = convert(Node, ex)
246+
pullback = let tree = tree
247+
d_primal -> let
248+
# ^The exact same tangent with respect to constants, so we can just take it.
249+
d_ex = @thunk(
250+
let
251+
parametric_node_tangent = NodeTangent(tree, d_primal.gradient)
252+
(;
253+
tree=parametric_node_tangent,
254+
metadata=(;
255+
_data=(;
256+
operators=NoTangent(),
257+
variable_names=NoTangent(),
258+
parameters=NoTangent(),
259+
parameter_names=NoTangent(),
260+
)
261+
),
262+
)
263+
end
264+
)
265+
(NoTangent(), NoTangent(), d_ex)
266+
end
267+
end
268+
return primal, pullback
269+
end
270+
239271
#! format: off
240272
function (ex::ParametricExpression)(X::AbstractMatrix, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...)
241273
return eval_tree_array(ex, X, operators; kws...) # Will error
@@ -250,7 +282,7 @@ function (ex::ParametricExpression)(
250282
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
251283
kws...,
252284
) where {T}
253-
(output, flag) = eval_tree_array(ex, X, classes, operators; kws...) # Will error
285+
(output, flag) = eval_tree_array(ex, X, classes, operators; kws...)
254286
if !flag
255287
output .= NaN
256288
end
@@ -276,6 +308,7 @@ function eval_tree_array(
276308
regular_tree = convert(Node, ex)
277309
return eval_tree_array(regular_tree, params_and_X, get_operators(ex, operators); kws...)
278310
end
311+
279312
function string_tree(
280313
ex::ParametricExpression,
281314
operators::Union{AbstractOperatorEnum,Nothing}=nothing;

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
44
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
55
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
6+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
67
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
78
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

test/test_parametric_expression.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,54 @@ end
266266
)
267267
end
268268
end
269+
270+
@testitem "Parametric expression derivatives" begin
271+
using DynamicExpressions
272+
using Zygote: Zygote
273+
using Random: MersenneTwister
274+
using DifferentiationInterface: value_and_gradient, AutoZygote
275+
276+
rng = MersenneTwister(0)
277+
X = rand(rng, 2, 32)
278+
true_params = [0.5 2.0]
279+
init_params = [0.1 0.2]
280+
classes = rand(rng, 1:2, 32)
281+
y = [X[1, i] * X[1, i] - cos(2.6 * X[2, i]) + true_params[1, classes[i]] for i in 1:32]
282+
283+
(true_val, true_grad) =
284+
value_and_gradient(AutoZygote(), (X, init_params, [2.5])) do (X, params, c)
285+
pred = [
286+
X[1, i] * X[1, i] - cos(c[1] * X[2, i]) + params[1, classes[i]] for
287+
i in 1:32
288+
]
289+
sum(abs2, pred .- y)
290+
end
291+
292+
operators = OperatorEnum(; unary_operators=[cos], binary_operators=[+, *, -])
293+
ex = @parse_expression(
294+
x * x - cos(2.5 * y) + p1,
295+
operators = operators,
296+
expression_type = ParametricExpression,
297+
variable_names = ["x", "y"],
298+
extra_metadata = (parameter_names=["p1"], parameters=init_params)
299+
)
300+
f = let operators = operators, X = X, classes = classes, y = y
301+
ex -> sum(abs2, ex(X, classes) .- y)
302+
end
303+
@test f(ex) isa Float64
304+
(val, grad) = value_and_gradient(f, AutoZygote(), ex)
305+
306+
@test val isa Float64
307+
@test grad isa NamedTuple
308+
@test grad.tree isa DynamicExpressions.ChainRulesModule.NodeTangent{
309+
Float64,ParametricNode{Float64},Vector{Float64}
310+
}
311+
@test grad.metadata._data.parameters isa Matrix{Float64}
312+
313+
# Loss value:
314+
@test val ≈ true_val
315+
# Gradient w.r.t. the constant:
316+
@test grad.tree.gradient ≈ true_grad[3]
317+
# Gradient w.r.t. the parameters:
318+
@test grad.metadata._data.parameters ≈ true_grad[2]
319+
end

0 commit comments

Comments (0)