Skip to content

feat: support observed generation for array expressions #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@ version = "0.3.17"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
Pkg = "1"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.0.1"
Expand Down
1 change: 0 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module SymbolicIndexingInterface

import MacroTools
using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface
Expand Down
68 changes: 46 additions & 22 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,44 +69,68 @@ function independent_variable_symbols(sc::SymbolCache)
end
is_observed(sc::SymbolCache, sym) = false
is_observed(::SymbolCache, ::Expr) = true
is_observed(::SymbolCache, ::AbstractArray{Expr}) = true
is_observed(::SymbolCache, ::Tuple{Vararg{Expr}}) = true

struct ExpressionSearcher
declared::Set{Symbol}
fnbody::Expr
end

ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Expr(:block))

function (exs::ExpressionSearcher)(sys, expr::Expr)
for arg in expr.args
exs(sys, arg)
end
exs(sys, expr.head)
return nothing
end

function (exs::ExpressionSearcher)(sys, sym::Symbol)
sym in exs.declared && return
if is_variable(sys, sym)
idx = variable_index(sys, sym)
push!(exs.fnbody.args, :($sym = u[$idx]))
elseif is_parameter(sys, sym)
idx = parameter_index(sys, sym)
push!(exs.fnbody.args, :($sym = p[$idx]))
elseif is_independent_variable(sys, sym)
push!(exs.fnbody.args, :($sym = t))
end
push!(exs.declared, sym)
return nothing
end

(::ExpressionSearcher)(sys, sym) = nothing

function observed(sc::SymbolCache, expr::Expr)
let cache = Dict{Expr, Function}()
return get!(cache, expr) do
fnbody = Expr(:block)
declared = Set{Symbol}()
MacroTools.postwalk(expr) do sym
sym isa Symbol || return
sym in declared && return
if sc.variables !== nothing &&
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
push!(fnbody.args, :($sym = u[$idx]))
push!(declared, sym)
elseif sc.parameters !== nothing &&
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
push!(fnbody.args, :($sym = p[$idx]))
push!(declared, sym)
elseif sym === sc.independent_variables ||
sc.independent_variables isa Vector &&
sym == only(sc.independent_variables)
push!(fnbody.args, :($sym = t))
push!(declared, sym)
end
end
exs = ExpressionSearcher()
exs(sc, expr)
fnexpr = if is_time_dependent(sc)
:(function (u, p, t)
$fnbody
$(exs.fnbody)
return $expr
end)
else
:(function (u, p)
$fnbody
$(exs.fnbody)
return $expr
end)
end
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)
end
end
end
function observed(sc::SymbolCache, exprs::AbstractArray{Expr})
return observed(sc, :(reshape([$(exprs...)], $(size(exprs)))))
end
function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}})
return observed(sc, :(($(exprs...),)))
end

function is_time_dependent(sc::SymbolCache)
sc.independent_variables === nothing && return false
if symbolic_type(sc.independent_variables) == NotSymbolic()
Expand Down
11 changes: 11 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ obsfn = observed(sc, :(x + a + t))
obsfn2 = observed(sc, :(x + a + t))
@test obsfn === obsfn2

@test is_observed(sc, [:(x + a), :(a + t)])
obsfn3 = observed(sc, [:(x + a), :(a + t)])
@test obsfn3(ones(3), 2ones(2), 3.0) ≈ [3.0, 5.0]
@test is_observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)])
@test size(obsfn4(ones(3), 2ones(2), 3.0)) == (2, 2)
@test obsfn4(ones(3), 2ones(2), 3.0) ≈ [3.0 3.0; 2.0 4.0]
@test is_observed(sc, (:(x + a), :(y + b)))
obsfn5 = observed(sc, (:(x + a), :(y + b)))
@test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0))

sc = SymbolCache([:x, :y], [:a, :b])
@test !is_time_dependent(sc)
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
Expand Down