Skip to content

Commit 3ff8187

Browse files
test: add simple adjoint tests
1 parent 783f4b9 commit 3ff8187

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ SafeTestsets = "0.0.1"
1919
StaticArrays = "1.9"
2020
StaticArraysCore = "1.4"
2121
Test = "1"
22+
Zygote = "0.6.67"
2223
julia = "1.10"
2324

2425
[extras]
@@ -27,6 +28,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2728
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2829
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2930
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3032

3133
[targets]
32-
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]
34+
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Zygote"]

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ if GROUP == "All" || GROUP == "Core"
4545
@safetestset "BatchedInterface test" begin
4646
@time include("batched_interface_test.jl")
4747
end
48+
@safetestset "Simple Adjoints test" begin
49+
@time include("simple_adjoints_test.jl")
50+
end
4851
end
4952

5053
if GROUP == "All" || GROUP == "Downstream"

test/simple_adjoints_test.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using SymbolicIndexingInterface
2+
using Zygote
3+
4+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
5+
pstate = ProblemState(; u = rand(3), p = rand(3), t = rand())
6+
7+
getter = getu(sys, :x)
8+
@test Zygote.gradient(getter, pstate)[1].u == [1.0, 0.0, 0.0]
9+
10+
getter = getu(sys, [:x, :z])
11+
@test Zygote.gradient(sum getter, pstate)[1].u == [1.0, 0.0, 1.0]
12+
13+
getter = getu(sys, :a)
14+
@test Zygote.gradient(getter, pstate)[1].p == [1.0, 0.0, 0.0]
15+
16+
getter = getu(sys, [:a, :c])
17+
@test Zygote.gradient(sum getter, pstate)[1].p == [1.0, 0.0, 1.0]

0 commit comments

Comments
 (0)