Skip to content

Commit 49713d6

Browse files
committed
test: verify gradients of parametric expressions
1 parent 2711934 commit 49713d6

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

test/test_parametric_expression.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,17 +275,27 @@ end
275275

276276
rng = MersenneTwister(0)
277277
X = rand(rng, 2, 32)
278+
true_params = [0.5 2.0]
279+
init_params = [0.1 0.2]
278280
classes = rand(rng, 1:2, 32)
279-
y = @. X[1, :] * X[1, :] - cos(2.6 * X[2, :]) + classes
281+
y = [X[1, i] * X[1, i] - cos(2.6 * X[2, i]) + true_params[1, classes[i]] for i in 1:32]
280282

281-
operators = OperatorEnum(; unary_operators=[cos], binary_operators=[+, *, -])
283+
(true_val, true_grad) =
284+
value_and_gradient(AutoZygote(), (X, init_params, [2.5])) do (X, params, c)
285+
pred = [
286+
X[1, i] * X[1, i] - cos(c[1] * X[2, i]) + params[1, classes[i]] for
287+
i in 1:32
288+
]
289+
sum(abs2, pred .- y)
290+
end
282291

292+
operators = OperatorEnum(; unary_operators=[cos], binary_operators=[+, *, -])
283293
ex = @parse_expression(
284294
x * x - cos(2.5 * y) + p1,
285295
operators = operators,
286296
expression_type = ParametricExpression,
287297
variable_names = ["x", "y"],
288-
extra_metadata = (parameter_names=["p1"], parameters=[0.5 0.2])
298+
extra_metadata = (parameter_names=["p1"], parameters=init_params)
289299
)
290300
f = let operators = operators, X = X, classes = classes, y = y
291301
ex -> sum(abs2, ex(X, classes) .- y)
@@ -299,4 +309,11 @@ end
299309
Float64,ParametricNode{Float64},Vector{Float64}
300310
}
301311
@test grad.metadata._data.parameters isa Matrix{Float64}
312+
313+
# Loss value:
314+
@test val true_val
315+
# Gradient w.r.t. the constant:
316+
@test grad.tree.gradient true_grad[3]
317+
# Gradient w.r.t. the parameters:
318+
@test grad.metadata._data.parameters true_grad[2]
302319
end

0 commit comments

Comments
 (0)