Skip to content

Commit 5912dd7

Browse files
Merge pull request #469 from SciML/myb/register_der
Document registering the derivative of a function
2 parents 8b2a3c8 + faf5641 commit 5912dd7

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/function_registration.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,31 @@ ModelingToolkit IR. Example:
1010
```
1111
1212
registers `f` as a possible two-argument function.
13+
14+
You may also want to tell ModelingToolkit the derivative of the registered
15+
function. Here is an example to do it
16+
17+
```julia
18+
julia> using ModelingToolkit
19+
20+
julia> foo(x, y) = sin(x) * cos(y)
21+
foo (generic function with 1 method)
22+
23+
julia> @parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
24+
25+
julia> @register foo(x, y)
26+
foo (generic function with 4 methods)
27+
28+
julia> foo(x, y)
29+
foo(x(t), y(t))
30+
31+
julia> ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
32+
33+
julia> ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
34+
35+
julia> isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))
36+
true
37+
```
1338
"""
1439
macro register(sig)
1540
splitsig = splitdef(:($sig = nothing))
@@ -116,4 +141,4 @@ function inject_registered_module_functions(expr)
116141

117142
return x
118143
end
119-
end
144+
end

test/function_registration.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,15 @@ sys = ODESystem([eq], t, [u], [x])
6262
fun = ODEFunction(sys)
6363

6464
@test fun([0.5], [7.0], 0.) == [74.0]
65+
66+
# derivative
67+
foo(x, y) = sin(x) * cos(y)
68+
@parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
69+
@register foo(x, y)
70+
expr = foo(x, y)
71+
@test expr.op === foo
72+
@test expr.args[1] === x
73+
@test expr.args[2] === y
74+
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
75+
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
76+
@test isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))

0 commit comments

Comments
 (0)