Skip to content

Commit c0cf44a

Browse files
committed
Add DynamicQuantities support
1 parent ec3ac52 commit c0cf44a

File tree

5 files changed

+432
-20
lines changed

5 files changed

+432
-20
lines changed

src/systems/unit_check.jl

Lines changed: 250 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import DynamicQuantities, Unitful
22
const DQ = DynamicQuantities
33

4+
#For dispatching get_unit
5+
const Conditional = Union{typeof(ifelse), typeof(IfElse.ifelse)}
6+
const Comparison = Union{typeof.([==, !=, , <, <=, , >, >=, ])...}
7+
48
struct ValidationError <: Exception
59
message::String
610
end
@@ -42,20 +46,260 @@ function __get_unit_type(vs′...)
4246
return nothing
4347
end
4448

45-
function check_units(::Val{:DynamicQuantities}, eqs...)
46-
validate(eqs...) ||
47-
throw(ValidationError("Some equations had invalid units. See warnings for details."))
48-
end
49-
50-
function screen_units(result)
49+
function screen_unit(result)
5150
if result isa DQ.AbstractQuantity
5251
d = DQ.dimension(result)
5352
if d isa DQ.Dimensions
53+
if result != oneunit(result)
54+
throw(ValidationError("$result uses non SI unit. Please use SI unit only."))
55+
end
5456
return result
5557
elseif d isa DQ.SymbolicDimensions
5658
throw(ValidationError("$result uses SymbolicDimensions, please use `u\"m\"` to instantiate SI unit only."))
5759
else
5860
throw(ValidationError("$result doesn't use SI unit, please use `u\"m\"` to instantiate SI unit only."))
5961
end
62+
else
63+
throw(ValidationError("$result doesn't have any unit."))
64+
end
65+
end
66+
67+
const unitless = DQ.Quantity(1.0)
68+
get_literal_unit(x) = screen_unit(something(__get_literal_unit(x), unitless))
69+
70+
"""
71+
Find the unit of a symbolic item.
72+
"""
73+
get_unit(x::Real) = unitless
74+
get_unit(x::DQ.AbstractQuantity) = screen_unit(oneunit(x))
75+
get_unit(x::AbstractArray) = map(get_unit, x)
76+
get_unit(x::Num) = get_unit(unwrap(x))
77+
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
78+
get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t)
79+
get_unit(op::typeof(getindex), args) = get_unit(args[1])
80+
get_unit(x::SciMLBase.NullParameters) = unitless
81+
get_unit(op::typeof(instream), args) = get_unit(args[1])
82+
83+
function get_unit(op, args) # Fallback
84+
result = op(get_unit.(args)...)
85+
try
86+
oneunit(result)
87+
catch
88+
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
89+
end
90+
end
91+
92+
function get_unit(op::Integral, args)
93+
unit = 1
94+
if op.domain.variables isa Vector
95+
for u in op.domain.variables
96+
unit *= get_unit(u)
97+
end
98+
else
99+
unit *= get_unit(op.domain.variables)
100+
end
101+
return oneunit(get_unit(args[1]) * unit)
102+
end
103+
104+
equivalent(x, y) = isequal(x, y)
105+
function get_unit(op::Conditional, args)
106+
terms = get_unit.(args)
107+
terms[1] == unitless ||
108+
throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless."))
109+
equivalent(terms[2], terms[3]) ||
110+
throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match."))
111+
return terms[2]
112+
end
113+
114+
function get_unit(op::typeof(Symbolics._mapreduce), args)
115+
if args[2] == +
116+
get_unit(args[3])
117+
else
118+
throw(ValidationError("Unsupported array operation $op"))
119+
end
120+
end
121+
122+
function get_unit(op::Comparison, args)
123+
terms = get_unit.(args)
124+
equivalent(terms[1], terms[2]) ||
125+
throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match."))
126+
return unitless
127+
end
128+
129+
function get_unit(x::Symbolic)
130+
if (u = __get_literal_unit(x)) !== nothing
131+
screen_unit(u)
132+
elseif issym(x)
133+
get_literal_unit(x)
134+
elseif isadd(x)
135+
terms = get_unit.(arguments(x))
136+
firstunit = terms[1]
137+
for other in terms[2:end]
138+
termlist = join(map(repr, terms), ", ")
139+
equivalent(other, firstunit) ||
140+
throw(ValidationError(", in sum $x, units [$termlist] do not match."))
141+
end
142+
return firstunit
143+
elseif ispow(x)
144+
pargs = arguments(x)
145+
base, expon = get_unit.(pargs)
146+
@assert oneunit(expon) == unitless
147+
if base == unitless
148+
unitless
149+
else
150+
pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2]
151+
end
152+
elseif istree(x)
153+
op = operation(x)
154+
if issym(op) || (istree(op) && istree(operation(op))) # Dependent variables, not function calls
155+
return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i]
156+
elseif istree(op) && !istree(operation(op))
157+
gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t)
158+
return screen_unit(getmetadata(gp, VariableUnit, unitless))
159+
end # Actual function calls:
160+
args = arguments(x)
161+
return get_unit(op, args)
162+
else # This function should only be reached by Terms, for which `istree` is true
163+
throw(ArgumentError("Unsupported value $x."))
164+
end
165+
end
166+
167+
"""
168+
Get unit of term, returning nothing & showing warning instead of throwing errors.
169+
"""
170+
function safe_get_unit(term, info)
171+
side = nothing
172+
try
173+
side = get_unit(term)
174+
catch err
175+
if err isa DQ.DimensionError
176+
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
177+
elseif err isa ValidationError
178+
@warn(info*err.message)
179+
elseif err isa MethodError
180+
@warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).")
181+
else
182+
rethrow()
183+
end
60184
end
185+
side
186+
end
187+
188+
function _validate(terms::Vector, labels::Vector{String}; info::String = "")
189+
valid = true
190+
first_unit = nothing
191+
first_label = nothing
192+
for (term, label) in zip(terms, labels)
193+
equnit = safe_get_unit(term, info * label)
194+
if equnit === nothing
195+
valid = false
196+
elseif !isequal(term, 0)
197+
if first_unit === nothing
198+
first_unit = equnit
199+
first_label = label
200+
elseif !equivalent(first_unit, equnit)
201+
valid = false
202+
@warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.")
203+
end
204+
end
205+
end
206+
valid
207+
end
208+
209+
function _validate(conn::Connection; info::String = "")
210+
valid = true
211+
syss = get_systems(conn)
212+
sys = first(syss)
213+
st = states(sys)
214+
for i in 2:length(syss)
215+
s = syss[i]
216+
sst = states(s)
217+
if length(st) != length(sst)
218+
valid = false
219+
@warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(st)) and $(length(sst)) states, cannor connect.")
220+
continue
221+
end
222+
for (i, x) in enumerate(st)
223+
j = findfirst(isequal(x), sst)
224+
if j == nothing
225+
valid = false
226+
@warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same states.")
227+
else
228+
aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i")
229+
bunit = safe_get_unit(sst[j], info * string(nameof(s)) * "#$j")
230+
if !equivalent(aunit, bunit)
231+
valid = false
232+
@warn("$info: connected system states $x and $(sst[j]) have mismatched units.")
233+
end
234+
end
235+
end
236+
end
237+
valid
238+
end
239+
240+
function validate(jump::Union{VariableRateJump,
241+
ConstantRateJump}, t::Symbolic;
242+
info::String = "")
243+
newinfo = replace(info, "eq." => "jump")
244+
_validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units
245+
validate(jump.affect!, info = newinfo)
246+
end
247+
248+
function validate(jump::MassActionJump, t::Symbolic; info::String = "")
249+
left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols
250+
net_symbols = [x[1] for x in jump.net_stoch]
251+
all_symbols = vcat(left_symbols, net_symbols)
252+
allgood = _validate(all_symbols, string.(all_symbols); info)
253+
n = sum(x -> x[2], jump.reactant_stoch, init = 0)
254+
base_unitful = all_symbols[1] #all same, get first
255+
allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)],
256+
["scaled_rates", "1/(t*reactants^$n))"]; info)
257+
end
258+
259+
function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Symbolic)
260+
labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"]
261+
all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3])
262+
end
263+
264+
function validate(eq::Equation; info::String = "")
265+
if typeof(eq.lhs) == Connection
266+
_validate(eq.rhs; info)
267+
else
268+
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
269+
end
270+
end
271+
function validate(eq::Equation,
272+
term::Union{Symbolic, DQ.AbstractQuantity, Num}; info::String = "")
273+
_validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info)
274+
end
275+
function validate(eq::Equation, terms::Vector; info::String = "")
276+
_validate(vcat([eq.lhs, eq.rhs], terms),
277+
vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info)
278+
end
279+
280+
"""
281+
Returns true iff units of equations are valid.
282+
"""
283+
function validate(eqs::Vector; info::String = "")
284+
all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)])
285+
end
286+
function validate(eqs::Vector, noise::Vector; info::String = "")
287+
all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx")
288+
for idx in 1:length(eqs)])
289+
end
290+
function validate(eqs::Vector, noise::Matrix; info::String = "")
291+
all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx")
292+
for idx in 1:length(eqs)])
293+
end
294+
function validate(eqs::Vector, term::Symbolic; info::String = "")
295+
all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)])
296+
end
297+
validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing
298+
299+
"""
300+
Throws error if units of equations are invalid.
301+
"""
302+
function check_units(::Val{:DynamicQuantities}, eqs...)
303+
validate(eqs...) ||
304+
throw(ValidationError("Some equations had invalid units. See warnings for details."))
61305
end

src/systems/validation.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module UnitfulUnitCheck
22

33
using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, IfElse, RecursiveArrayTools
44
using ..ModelingToolkit: ValidationError,
5-
ModelingToolkit, Connection, instream, JumpType, VariableUnit, get_systems
5+
ModelingToolkit, Connection, instream, JumpType, VariableUnit, get_systems,
6+
Conditional, Comparison
67
using Symbolics: Symbolic, value, issym, isadd, ismul, ispow
78
const MT = ModelingToolkit
89

@@ -39,10 +40,6 @@ MT = ModelingToolkit
3940
equivalent(x, y) = isequal(1 * x, 1 * y)
4041
const unitless = Unitful.unit(1)
4142

42-
#For dispatching get_unit
43-
const Conditional = Union{typeof(ifelse), typeof(IfElse.ifelse)}
44-
const Comparison = Union{typeof.([==, !=, , <, <=, , >, >=, ])...}
45-
4643
"""
4744
Find the unit of a symbolic item.
4845
"""

0 commit comments

Comments
 (0)