Skip to content

Commit ea53e12

Browse files
refactor: use getu for symbolic indexing, remove implicit LabelledArrays dependency
1 parent c664868 commit ea53e12

File tree

3 files changed

+21
-74
lines changed

3 files changed

+21
-74
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ FastBroadcast = "0.2.8"
4444
ForwardDiff = "0.10.19"
4545
GPUArraysCore = "0.1.1"
4646
IteratorInterfaceExtensions = "1"
47-
LabelledArrays = "1.15"
4847
LinearAlgebra = "1.10"
4948
Measurements = "2.3"
5049
MonteCarloMeasurements = "1.1"
@@ -60,7 +59,7 @@ StaticArrays = "1.6"
6059
StaticArraysCore = "1.4"
6160
Statistics = "1.10"
6261
StructArrays = "0.6.11"
63-
SymbolicIndexingInterface = "0.3.2"
62+
SymbolicIndexingInterface = "0.3.19"
6463
Tables = "1.11"
6564
Test = "1"
6665
Tracker = "0.2.15"
@@ -72,7 +71,6 @@ julia = "1.10"
7271
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7372
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
7473
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75-
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7674
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
7775
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
7876
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
@@ -88,4 +86,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
8886
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8987

9088
[targets]
91-
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
89+
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]

src/vector_of_array.jl

Lines changed: 17 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -352,73 +352,27 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
352352
end
353353

354354
# Symbolic Indexing Methods
355-
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
356-
if is_independent_variable(A, sym)
357-
return A.t
358-
elseif is_variable(A, sym)
359-
if constant_structure(A)
360-
return getindex.(A.u, variable_index(A, sym))
361-
else
362-
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
355+
for symtype in [ScalarSymbolic, ArraySymbolic]
356+
paramcheck = quote
357+
if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym)
358+
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
363359
end
364-
elseif is_parameter(A, sym)
365-
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
366-
elseif is_observed(A, sym)
367-
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
368-
else
369-
# NOTE: this is basically just for LabelledArrays. It's better if this
370-
# were an error. Should we make an extension for LabelledArrays handling
371-
# this case?
372-
return getindex.(A.u, sym)
373360
end
374-
end
375-
376-
Base.@propagate_inbounds function _getindex(
377-
A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
378-
if is_independent_variable(A, sym)
379-
return A.t[args...]
380-
elseif is_variable(A, sym)
381-
return A[sym][args...]
382-
elseif is_observed(A, sym)
383-
u = A.u[args...]
384-
t = A.t[args...]
385-
observed_fn = observed(A, sym)
386-
if t isa AbstractArray
387-
return observed_fn.(u, (parameter_values(A),), t)
388-
else
389-
return observed_fn(u, parameter_values(A), t)
390-
end
391-
else
392-
# NOTE: this is basically just for LabelledArrays. It's better if this
393-
# were an error. Should we make an extension for LabelledArrays handling
394-
# this case?
395-
return getindex.(A.u[args...], sym)
361+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym)
362+
$paramcheck
363+
getu(A, sym)(A)
396364
end
397-
end
398-
399-
Base.@propagate_inbounds function _getindex(
400-
A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
401-
return getindex(A, collect(sym), args...)
402-
end
403-
404-
Base.@propagate_inbounds function _getindex(
405-
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
406-
if all(x -> is_parameter(A, x), sym)
407-
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
408-
else
409-
return A[sym, eachindex(A.t)]
365+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg)
366+
$paramcheck
367+
getu(A, sym)(A, arg)
410368
end
411-
end
412-
413-
Base.@propagate_inbounds function _getindex(
414-
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...)
415-
u = A.u[args...]
416-
t = A.t[args...]
417-
observed_fn = observed(A, sym)
418-
if t isa AbstractArray
419-
return observed_fn.(u, (parameter_values(A),), t)
420-
else
421-
return observed_fn(u, parameter_values(A), t)
369+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
370+
$paramcheck
371+
getu(A, sym).((A,), arg)
372+
end
373+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon)
374+
$paramcheck
375+
getu(A, sym)(A)
422376
end
423377
end
424378

test/symbolic_indexing_interface_test.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, Test, LabelledArrays, SymbolicIndexingInterface
1+
using RecursiveArrayTools, Test, SymbolicIndexingInterface
22

33
t = 0.0:0.1:1.0
44
f(x) = 2x
@@ -20,7 +20,7 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
2020
@test dx[[:a, :b]] [[f(x), f2(x)] for x in t]
2121
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
2222
@test dx[[:a, :b], 3] [f(t[3]), f2(t[3])]
23-
@test dx[[:a, :b], 4:5] vcat(f.(t[4:5])', f2.(t[4:5])')
23+
@test dx[[:a, :b], 4:5] vcat.(f.(t[4:5]), f2.(t[4:5]))
2424
@test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]]
2525
@test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3]
2626
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
@@ -53,8 +53,3 @@ get_tuple = getu(dx, (:a, :b))
5353

5454
dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b])
5555
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym
56-
57-
ABC = @SLVector (:a, :b, :c);
58-
A = ABC(1, 2, 3);
59-
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
60-
@test getindex(B, :a) == [1, 1]

0 commit comments

Comments
 (0)