Skip to content

Commit faf5641

Browse files
committed
Fix docs and add tests
1 parent 2224d14 commit faf5641

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

src/function_registration.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,28 @@ ModelingToolkit IR. Example:
1212
registers `f` as a possible two-argument function.
1313
1414
You may also want to tell ModelingToolkit the derivative of the registered
15-
function. You can achieve this by
15+
function. Here is an example to do it
1616
1717
```julia
18+
julia> using ModelingToolkit
19+
1820
julia> foo(x, y) = sin(x) * cos(y)
1921
foo (generic function with 1 method)
2022
21-
julia> ModelingToolkit.derivative(::typeof(foo), x, y, ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
23+
julia> @parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
2224
23-
julia> ModelingToolkit.derivative(::typeof(foo), x, y, ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
25+
julia> @register foo(x, y)
26+
foo (generic function with 4 methods)
2427
25-
julia> @parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
28+
julia> foo(x, y)
29+
foo(x(t), y(t))
30+
31+
julia> ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
32+
33+
julia> ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
2634
27-
julia> expand_derivatives(D(foo(x, y) * z))
28-
z(t) * (derivative(x(t), t) * cos(x(t)) * cos(y(t)) + -1 * sin(x(t)) * derivative(y(t), t) * sin(y(t))) + sin(x(t)) * cos(y(t)) * derivative(z(t), t)
35+
julia> isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))
36+
true
2937
```
3038
"""
3139
macro register(sig)

test/function_registration.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,15 @@ sys = ODESystem([eq], t, [u], [x])
6262
fun = ODEFunction(sys)
6363

6464
@test fun([0.5], [7.0], 0.) == [74.0]
65+
66+
# derivative
67+
foo(x, y) = sin(x) * cos(y)
68+
@parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
69+
@register foo(x, y)
70+
expr = foo(x, y)
71+
@test expr.op === foo
72+
@test expr.args[1] === x
73+
@test expr.args[2] === y
74+
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
75+
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
76+
@test isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))

0 commit comments

Comments
 (0)