@@ -84,21 +84,19 @@ function generate_tgrad(sys::AbstractODESystem, dvs = unknowns(sys), ps = parame
84
84
simplify = false , kwargs... )
85
85
tgrad = calculate_tgrad (sys, simplify = simplify)
86
86
pre = get_preprocess_constants (tgrad)
87
- if ps isa Tuple
88
- return build_function (tgrad,
89
- dvs,
90
- ps... ,
91
- get_iv (sys);
92
- postprocess_fbody = pre,
93
- kwargs... )
87
+ p = if has_index_cache (sys)
88
+ reorder_parameters (get_index_cache (sys), ps)
89
+ elseif ps isa Tuple
90
+ ps
94
91
else
95
- return build_function (tgrad,
96
- dvs,
97
- ps,
98
- get_iv (sys);
99
- postprocess_fbody = pre,
100
- kwargs... )
92
+ (ps,)
101
93
end
94
+ return build_function (tgrad,
95
+ dvs,
96
+ p... ,
97
+ get_iv (sys);
98
+ postprocess_fbody = pre,
99
+ kwargs... )
102
100
end
103
101
104
102
function generate_jacobian (sys:: AbstractODESystem , dvs = unknowns (sys), ps = parameters (sys);
@@ -135,7 +133,12 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
135
133
@variables ˍ₋gamma
136
134
jac = ˍ₋gamma * jac_du + jac_u
137
135
pre = get_preprocess_constants (jac)
138
- return build_function (jac, derivatives, dvs, ps, ˍ₋gamma, get_iv (sys);
136
+ p = if has_index_cache (sys)
137
+ reorder_parameters (get_index_cache (sys), ps)
138
+ else
139
+ (ps,)
140
+ end
141
+ return build_function (jac, derivatives, dvs, p... , ˍ₋gamma, get_iv (sys);
139
142
postprocess_fbody = pre, kwargs... )
140
143
end
141
144
@@ -399,15 +402,25 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = u
399
402
build_explicit_observed_function (sys, obsvar)
400
403
end
401
404
if args === ()
402
- let obs = obs
403
- (u, p, t = Inf ) -> if ps isa Tuple
405
+ let obs = obs, ps_T = typeof (ps)
406
+ (u, p, t = Inf ) -> if p isa MTKParameters
407
+ obs (u, raw_vectors (p)... , t)
408
+ elseif ps_T <: Tuple
404
409
obs (u, p... , t)
405
410
else
406
411
obs (u, p, t)
407
412
end
408
413
end
409
414
else
410
- if ps isa Tuple
415
+ if args[2 ] isa MTKParameters
416
+ if length (args) == 2
417
+ u, p = args
418
+ obs (u, raw_vectors (p)... , Inf )
419
+ else
420
+ u, p, t = args
421
+ obs (u, raw_vectors (p)... , t)
422
+ end
423
+ elseif ps isa Tuple
411
424
if length (args) == 2
412
425
u, p = args
413
426
obs (u, p... , Inf )
@@ -437,16 +450,21 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = u
437
450
ps)
438
451
end
439
452
if args === ()
440
- let obs = obs
441
- (u, p, t) -> if ps isa Tuple
453
+ let obs = obs, ps_T = typeof (ps)
454
+ (u, p, t) -> if p isa MTKParameters
455
+ obs (u, raw_vectors (p)... , t)
456
+ elseif ps_T <: Tuple
442
457
obs (u, p... , t)
443
458
else
444
459
obs (u, p, t)
445
460
end
446
461
end
447
462
else
448
- if ps isa Tuple # split parameters
463
+ u, p, t = args
464
+ if p isa MTKParameters
449
465
u, p, t = args
466
+ obs (u, raw_vectors (p)... , t)
467
+ elseif ps isa Tuple # split parameters
450
468
obs (u, p... , t)
451
469
else
452
470
obs (args... )
@@ -518,7 +536,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
518
536
(drop_expr (@RuntimeGeneratedFunction (eval_module, ex)) for ex in f_gen) :
519
537
f_gen
520
538
f (du, u, p, t) = f_oop (du, u, p, t)
539
+ f (du, u, p:: MTKParameters , t) = f_oop (du, u, raw_vectors (p)... , t)
521
540
f (out, du, u, p, t) = f_iip (out, du, u, p, t)
541
+ f (out, du, u, p:: MTKParameters , t) = f_iip (out, du, u, raw_vectors (p)... , t)
522
542
523
543
if jac
524
544
jac_gen = generate_dae_jacobian (sys, dvs, ps;
@@ -530,8 +550,10 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
530
550
(drop_expr (@RuntimeGeneratedFunction (eval_module, ex)) for ex in jac_gen) :
531
551
jac_gen
532
552
_jac (du, u, p, ˍ₋gamma, t) = jac_oop (du, u, p, ˍ₋gamma, t)
553
+ _jac (du, u, p:: MTKParameters , ˍ₋gamma, t) = jac_oop (du, u, raw_vectors (p)... , ˍ₋gamma, t)
533
554
534
555
_jac (J, du, u, p, ˍ₋gamma, t) = jac_iip (J, du, u, p, ˍ₋gamma, t)
556
+ _jac (J, du, u, p:: MTKParameters , ˍ₋gamma, t) = jac_iip (J, du, u, raw_vectors (p)... , ˍ₋gamma, t)
535
557
else
536
558
_jac = nothing
537
559
end
@@ -544,10 +566,13 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
544
566
end
545
567
if args === ()
546
568
let obs = obs
547
- (u, p, t) -> obs (u, p, t)
569
+ fun (u, p, t) = obs (u, p, t)
570
+ fun (u, p:: MTKParameters , t) = obs (u, raw_vectors (p)... , t)
571
+ fun
548
572
end
549
573
else
550
- obs (args... )
574
+ u, p, t = args
575
+ p isa MTKParameters ? obs (u, raw_vectors (p)... , t) : obs (u, p, t)
551
576
end
552
577
end
553
578
end
@@ -591,7 +616,9 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
591
616
kwargs... )
592
617
f_oop, f_iip = (drop_expr (@RuntimeGeneratedFunction (eval_module, ex)) for ex in f_gen)
593
618
f (u, h, p, t) = f_oop (u, h, p, t)
619
+ f (u, h, p:: MTKParameters , t) = f_oop (u, h, raw_vectors (p)... , t)
594
620
f (du, u, h, p, t) = f_iip (du, u, h, p, t)
621
+ f (du, u, h, p:: MTKParameters , t) = f_iip (du, u, h, raw_vectors (p)... , t)
595
622
596
623
DDEFunction {iip} (f, sys = sys)
597
624
end
@@ -617,9 +644,13 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
617
644
isdde = true , kwargs... )
618
645
g_oop, g_iip = (drop_expr (@RuntimeGeneratedFunction (ex)) for ex in g_gen)
619
646
f (u, h, p, t) = f_oop (u, h, p, t)
647
+ f (u, h, p:: MTKParameters , t) = f_oop (u, h, raw_vectors (p)... , t)
620
648
f (du, u, h, p, t) = f_iip (du, u, h, p, t)
649
+ f (du, u, h, p:: MTKParameters , t) = f_iip (du, u, h, raw_vectors (p)... , t)
621
650
g (u, h, p, t) = g_oop (u, h, p, t)
651
+ g (u, h, p:: MTKParameters , t) = g_oop (u, h, raw_vectors (p)... , t)
622
652
g (du, u, h, p, t) = g_iip (du, u, h, p, t)
653
+ g (du, u, h, p:: MTKParameters , t) = g_iip (du, u, h, raw_vectors (p)... , t)
623
654
624
655
SDDEFunction {iip} (f, g, sys = sys)
625
656
end
0 commit comments