Skip to content

Commit c6bf421

Browse files
Merge pull request #70 from SciML/as/array-observed
feat: support observed generation for array expressions
2 parents 5a78f02 + 3cae58c commit c6bf421

File tree

4 files changed

+57
-25
lines changed

4 files changed

+57
-25
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@ version = "0.3.18"
66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
109
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1110
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1211

1312
[compat]
1413
Accessors = "0.1.36"
1514
Aqua = "0.8"
1615
ArrayInterface = "7.9"
17-
MacroTools = "0.5.13"
1816
Pkg = "1"
1917
RuntimeGeneratedFunctions = "0.5.12"
2018
SafeTestsets = "0.0.1"

src/SymbolicIndexingInterface.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module SymbolicIndexingInterface
22

3-
import MacroTools
43
using RuntimeGeneratedFunctions
54
import StaticArraysCore: MArray, similar_type
65
import ArrayInterface

src/symbol_cache.jl

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,44 +69,68 @@ function independent_variable_symbols(sc::SymbolCache)
6969
end
7070
is_observed(sc::SymbolCache, sym) = false
7171
is_observed(::SymbolCache, ::Expr) = true
72+
is_observed(::SymbolCache, ::AbstractArray{Expr}) = true
73+
is_observed(::SymbolCache, ::Tuple{Vararg{Expr}}) = true
74+
75+
struct ExpressionSearcher
76+
declared::Set{Symbol}
77+
fnbody::Expr
78+
end
79+
80+
ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Expr(:block))
81+
82+
function (exs::ExpressionSearcher)(sys, expr::Expr)
83+
for arg in expr.args
84+
exs(sys, arg)
85+
end
86+
exs(sys, expr.head)
87+
return nothing
88+
end
89+
90+
function (exs::ExpressionSearcher)(sys, sym::Symbol)
91+
sym in exs.declared && return
92+
if is_variable(sys, sym)
93+
idx = variable_index(sys, sym)
94+
push!(exs.fnbody.args, :($sym = u[$idx]))
95+
elseif is_parameter(sys, sym)
96+
idx = parameter_index(sys, sym)
97+
push!(exs.fnbody.args, :($sym = p[$idx]))
98+
elseif is_independent_variable(sys, sym)
99+
push!(exs.fnbody.args, :($sym = t))
100+
end
101+
push!(exs.declared, sym)
102+
return nothing
103+
end
104+
105+
(::ExpressionSearcher)(sys, sym) = nothing
106+
72107
function observed(sc::SymbolCache, expr::Expr)
73108
let cache = Dict{Expr, Function}()
74109
return get!(cache, expr) do
75-
fnbody = Expr(:block)
76-
declared = Set{Symbol}()
77-
MacroTools.postwalk(expr) do sym
78-
sym isa Symbol || return
79-
sym in declared && return
80-
if sc.variables !== nothing &&
81-
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
82-
push!(fnbody.args, :($sym = u[$idx]))
83-
push!(declared, sym)
84-
elseif sc.parameters !== nothing &&
85-
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
86-
push!(fnbody.args, :($sym = p[$idx]))
87-
push!(declared, sym)
88-
elseif sym === sc.independent_variables ||
89-
sc.independent_variables isa Vector &&
90-
sym == only(sc.independent_variables)
91-
push!(fnbody.args, :($sym = t))
92-
push!(declared, sym)
93-
end
94-
end
110+
exs = ExpressionSearcher()
111+
exs(sc, expr)
95112
fnexpr = if is_time_dependent(sc)
96113
:(function (u, p, t)
97-
$fnbody
114+
$(exs.fnbody)
98115
return $expr
99116
end)
100117
else
101118
:(function (u, p)
102-
$fnbody
119+
$(exs.fnbody)
103120
return $expr
104121
end)
105122
end
106123
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)
107124
end
108125
end
109126
end
127+
function observed(sc::SymbolCache, exprs::AbstractArray{Expr})
128+
return observed(sc, :(reshape([$(exprs...)], $(size(exprs)))))
129+
end
130+
function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}})
131+
return observed(sc, :(($(exprs...),)))
132+
end
133+
110134
function is_time_dependent(sc::SymbolCache)
111135
sc.independent_variables === nothing && return false
112136
if symbolic_type(sc.independent_variables) == NotSymbolic()

test/symbol_cache_test.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ obsfn = observed(sc, :(x + a + t))
3434
obsfn2 = observed(sc, :(x + a + t))
3535
@test obsfn === obsfn2
3636

37+
@test is_observed(sc, [:(x + a), :(a + t)])
38+
obsfn3 = observed(sc, [:(x + a), :(a + t)])
39+
@test obsfn3(ones(3), 2ones(2), 3.0) [3.0, 5.0]
40+
@test is_observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
41+
obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
42+
@test size(obsfn4(ones(3), 2ones(2), 3.0)) == (2, 2)
43+
@test obsfn4(ones(3), 2ones(2), 3.0) [3.0 3.0; 2.0 4.0]
44+
@test is_observed(sc, (:(x + a), :(y + b)))
45+
obsfn5 = observed(sc, (:(x + a), :(y + b)))
46+
@test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0))
47+
3748
sc = SymbolCache([:x, :y], [:a, :b])
3849
@test !is_time_dependent(sc)
3950
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]

0 commit comments

Comments
 (0)