Skip to content

Commit 2219c0b

Browse files
committed
test: gradients of parametric expression
1 parent bcfc360 commit 2219c0b

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

src/ParametricExpression.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ function ChainRulesCore.rrule(
293293
parameters = ex.metadata.parameters
294294
num_params = size(parameters, 1)
295295
num_classes = size(parameters, 2)
296+
_operators = get_operators(ex, operators)
296297
indexed_parameters = [
297298
parameters[i_parameter, classes[i_row]] for
298299
i_parameter in eachindex(axes(parameters, 1)), i_row in eachindex(classes)
@@ -302,10 +303,10 @@ function ChainRulesCore.rrule(
302303
regular_tree = convert(Node, ex)
303304

304305
_, gradient_tree, complete1 = eval_grad_tree_array(
305-
regular_tree, params_and_X, operators; variable=Val(false)
306+
regular_tree, params_and_X, _operators; variable=Val(false)
306307
)
307308
_, gradient_params_and_X, complete2 = eval_grad_tree_array(
308-
regular_tree, params_and_X, operators; variable=Val(true)
309+
regular_tree, params_and_X, _operators; variable=Val(true)
309310
)
310311

311312
if !complete1
@@ -331,13 +332,15 @@ function ChainRulesCore.rrule(
331332
d_ex = (;
332333
tree=d_tree,
333334
metadata=(;
334-
operators=NoTangent(),
335-
variable_names=NoTangent(),
336-
parameters=d_parameters,
337-
parameter_names=NoTangent(),
335+
_data=(;
336+
operators=NoTangent(),
337+
variable_names=NoTangent(),
338+
parameters=d_parameters,
339+
parameter_names=NoTangent(),
340+
),
338341
),
339342
)
340-
return (NoTangent(), d_ex, copy(d_X), NoTangent(), NoTangent())
343+
return (NoTangent(), d_ex, d_X, NoTangent(), NoTangent())
341344
end
342345

343346
return (primal, complete), pullback

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
44
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
55
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
6+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
67
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
78
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

test/test_parametric_expression.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,37 @@ end
266266
)
267267
end
268268
end
269+
270+
@testitem "Parametric expression derivatives" begin
271+
using DynamicExpressions
272+
using Zygote: Zygote
273+
using Random: MersenneTwister
274+
using DifferentiationInterface: value_and_gradient, AutoZygote
275+
276+
rng = MersenneTwister(0)
277+
X = rand(rng, 2, 32)
278+
classes = rand(rng, 1:2, 32)
279+
y = @. X[1, :] * X[1, :] - cos(2.6 * X[2, :]) + classes
280+
281+
operators = OperatorEnum(; unary_operators=[cos], binary_operators=[+, *, -])
282+
283+
ex = @parse_expression(
284+
x * x - cos(2.5 * y) + p1,
285+
operators = operators,
286+
expression_type = ParametricExpression,
287+
variable_names = ["x", "y"],
288+
extra_metadata = (parameter_names=["p1"], parameters=[0.5 0.2])
289+
)
290+
f = let operators = operators, X = X, classes = classes, y = y
291+
ex -> sum(abs2, ex(X, classes) .- y)
292+
end
293+
@test f(ex) isa Float64
294+
(val, grad) = value_and_gradient(f, AutoZygote(), ex)
295+
296+
@test val isa Float64
297+
@test grad isa NamedTuple
298+
@test grad.tree isa DynamicExpressions.ChainRulesModule.NodeTangent{
299+
Float64,ParametricNode{Float64},Vector{Float64}
300+
}
301+
@test grad.metadata._data.parameters isa Matrix{Float64}
302+
end

0 commit comments

Comments
 (0)