Skip to content

Commit 3fc4124

Browse files
test: test callbacks without parameter splitting
1 parent c92a8b3 commit 3fc4124

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

test/symbolic_events.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ fsys = flatten(sys)
138138
@test isequal(ModelingToolkit.continuous_events(sys2)[2].eqs[], sys.x ~ 1)
139139

140140
sys = complete(sys)
141+
sys_nosplit = complete(sys; split = false)
141142
sys2 = complete(sys2)
142143
# Functions should be generated for root-finding equations
143144
prob = ODEProblem(sys, Pair[], (0.0, 2.0))
@@ -155,15 +156,22 @@ cond.rf_ip(out, [2], p0, t0)
155156
@test out[] 1 # signature is u,p,t
156157

157158
prob = ODEProblem(sys, Pair[], (0.0, 2.0))
159+
prob_nosplit = ODEProblem(sys_nosplit, Pair[], (0.0, 2.0))
158160
sol = solve(prob, Tsit5())
161+
sol_nosplit = solve(prob_nosplit, Tsit5())
159162
@test minimum(t -> abs(t - 1), sol.t) < 1e-10 # test that the solver stepped at the root
163+
@test minimum(t -> abs(t - 1), sol_nosplit.t) < 1e-10 # test that the solver stepped at the root
160164

161165
# Test that a user provided callback is respected
162166
test_callback = DiscreteCallback(x -> x, x -> x)
163167
prob = ODEProblem(sys, Pair[], (0.0, 2.0), callback = test_callback)
168+
prob_nosplit = ODEProblem(sys_nosplit, Pair[], (0.0, 2.0), callback = test_callback)
164169
cbs = get_callback(prob)
170+
cbs_nosplit = get_callback(prob_nosplit)
165171
@test cbs isa CallbackSet
166172
@test cbs.discrete_callbacks[1] == test_callback
173+
@test cbs_nosplit isa CallbackSet
174+
@test cbs_nosplit.discrete_callbacks[1] == test_callback
167175

168176
prob = ODEProblem(sys2, Pair[], (0.0, 3.0))
169177
cb = get_callback(prob)
@@ -234,9 +242,11 @@ continuous_events = [[x ~ 0] => [vx ~ -vx]
234242
D(vy) ~ -0.01vy], t; continuous_events)
235243

236244
ball = structural_simplify(ball)
245+
ball_nosplit = structural_simplify(ball; split = false)
237246

238247
tspan = (0.0, 5.0)
239248
prob = ODEProblem(ball, Pair[], tspan)
249+
prob_nosplit = ODEProblem(ball_nosplit, Pair[], tspan)
240250

241251
cb = get_callback(prob)
242252
@test cb isa ModelingToolkit.DiffEqCallbacks.VectorContinuousCallback
@@ -250,9 +260,13 @@ cond.rf_ip(out, [0, 0, 0, 0], p0, t0)
250260
@test out [0, 1.5, -1.5]
251261

252262
sol = solve(prob, Tsit5())
263+
sol_nosplit = solve(prob_nosplit, Tsit5())
253264
@test 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
254265
@test minimum(sol[y]) -1.5 # check wall conditions
255266
@test maximum(sol[y]) 1.5 # check wall conditions
267+
@test 0 <= minimum(sol_nosplit[x]) <= 1e-10 # the ball never went through the floor but got very close
268+
@test minimum(sol_nosplit[y]) -1.5 # check wall conditions
269+
@test maximum(sol_nosplit[y]) 1.5 # check wall conditions
256270

257271
# tv = sort([LinRange(0, 5, 200); sol.t])
258272
# plot(sol(tv)[y], sol(tv)[x], line_z=tv)
@@ -270,13 +284,18 @@ continuous_events = [
270284
D(vx) ~ -1
271285
D(vy) ~ 0], t; continuous_events)
272286

287+
ball_nosplit = structural_simplify(ball)
273288
ball = structural_simplify(ball)
274289

275290
tspan = (0.0, 5.0)
276291
prob = ODEProblem(ball, Pair[], tspan)
292+
prob_nosplit = ODEProblem(ball_nosplit, Pair[], tspan)
277293
sol = solve(prob, Tsit5())
294+
sol_nosplit = solve(prob_nosplit, Tsit5())
278295
@test 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
279296
@test -minimum(sol[y]) maximum(sol[y]) sqrt(2) # the ball will never go further than √2 in either direction (gravity was changed to 1 to get this particular number)
297+
@test 0 <= minimum(sol_nosplit[x]) <= 1e-10 # the ball never went through the floor but got very close
298+
@test -minimum(sol_nosplit[y]) maximum(sol_nosplit[y]) sqrt(2) # the ball will never go further than √2 in either direction (gravity was changed to 1 to get this particular number)
280299

281300
# tv = sort([LinRange(0, 5, 200); sol.t])
282301
# plot(sol(tv)[y], sol(tv)[x], line_z=tv)

0 commit comments

Comments
 (0)