Skip to content

Commit 9508294

Browse files
Merge pull request #369 from AayushSabharwal/myb/batch
refactor: use getu for symbolic indexing
2 parents b568dde + d3c2ecd commit 9508294

File tree

6 files changed

+29
-71
lines changed

6 files changed

+29
-71
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ jobs:
3131
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
3232
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
3333
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
34+
- {user: SciML, repo: LabelledArrays.jl, group: RecursiveArrayTools}
3435
steps:
3536
- uses: actions/checkout@v4
3637
- uses: julia-actions/setup-julia@v2

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ FastBroadcast = "0.2.8"
4444
ForwardDiff = "0.10.19"
4545
GPUArraysCore = "0.1.1"
4646
IteratorInterfaceExtensions = "1"
47-
LabelledArrays = "1.15"
4847
LinearAlgebra = "1.10"
4948
Measurements = "2.3"
5049
MonteCarloMeasurements = "1.1"
@@ -60,7 +59,7 @@ StaticArrays = "1.6"
6059
StaticArraysCore = "1.4"
6160
Statistics = "1.10"
6261
StructArrays = "0.6.11"
63-
SymbolicIndexingInterface = "0.3.2"
62+
SymbolicIndexingInterface = "0.3.19"
6463
Tables = "1.11"
6564
Test = "1"
6665
Tracker = "0.2.15"
@@ -72,7 +71,6 @@ julia = "1.10"
7271
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7372
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
7473
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75-
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7674
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
7775
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
7876
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
@@ -88,4 +86,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
8886
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8987

9088
[targets]
91-
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
89+
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]

src/vector_of_array.jl

Lines changed: 18 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -352,67 +352,28 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
352352
end
353353

354354
# Symbolic Indexing Methods
355-
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
356-
if is_independent_variable(A, sym)
357-
return A.t
358-
elseif is_variable(A, sym)
359-
if constant_structure(A)
360-
return getindex.(A.u, variable_index(A, sym))
361-
else
362-
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
355+
for symtype in [ScalarSymbolic, ArraySymbolic]
356+
paramcheck = quote
357+
if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym)
358+
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
363359
end
364-
elseif is_parameter(A, sym)
365-
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
366-
elseif is_observed(A, sym)
367-
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
368-
else
369-
# NOTE: this is basically just for LabelledArrays. It's better if this
370-
# were an error. Should we make an extension for LabelledArrays handling
371-
# this case?
372-
return getindex.(A.u, sym)
373360
end
374-
end
375-
376-
Base.@propagate_inbounds function _getindex(
377-
A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
378-
if is_independent_variable(A, sym)
379-
return A.t[args...]
380-
elseif is_variable(A, sym)
381-
return A[sym][args...]
382-
elseif is_observed(A, sym)
383-
u = A.u[args...]
384-
t = A.t[args...]
385-
observed_fn = observed(A, sym)
386-
if t isa AbstractArray
387-
return observed_fn.(u, (parameter_values(A),), t)
388-
else
389-
return observed_fn(u, parameter_values(A), t)
390-
end
391-
else
392-
# NOTE: this is basically just for LabelledArrays. It's better if this
393-
# were an error. Should we make an extension for LabelledArrays handling
394-
# this case?
395-
return getindex.(A.u[args...], sym)
361+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym)
362+
$paramcheck
363+
getu(A, sym)(A)
396364
end
397-
end
398-
399-
Base.@propagate_inbounds function _getindex(
400-
A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
401-
return getindex(A, collect(sym), args...)
402-
end
403-
404-
Base.@propagate_inbounds function _getindex(
405-
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
406-
if all(x -> is_parameter(A, x), sym)
407-
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
408-
else
409-
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
365+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg)
366+
$paramcheck
367+
getu(A, sym)(A, arg)
368+
end
369+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
370+
$paramcheck
371+
getu(A, sym).((A,), arg)
372+
end
373+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon)
374+
$paramcheck
375+
getu(A, sym)(A)
410376
end
411-
end
412-
413-
Base.@propagate_inbounds function _getindex(
414-
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...)
415-
return reduce(vcat, map(s -> A[s, args...]', sym))
416377
end
417378

418379
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,

test/downstream/symbol_indexing.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Test
2+
using Zygote
23
using ModelingToolkit: t_nounits as t, D_nounits as D
34

45
include("../testutils.jl")
@@ -35,10 +36,12 @@ sol_new = DiffEqArray(sol.u[1:10],
3536
@test_throws Exception sol_new[τ]
3637

3738
gs, = Zygote.gradient(sol) do sol
38-
sum(sol[fol_separate.x])
39+
sum(sol[fol_separate.x])
3940
end
4041

41-
@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u))
42+
@testset "Symbolic Indexing ADjoint" begin
43+
@test all(all.(isone, gs.u))
44+
end
4245

4346
# Tables interface
4447
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ end
4747
if GROUP == "SymbolicIndexingInterface" || GROUP == "Downstream"
4848
if GROUP == "SymbolicIndexingInterface"
4949
activate_downstream_env()
50-
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
5150
end
51+
@time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl")
5252
end
5353

5454
if GROUP == "GPU"

test/symbolic_indexing_interface_test.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, Test, LabelledArrays, SymbolicIndexingInterface
1+
using RecursiveArrayTools, Test, SymbolicIndexingInterface
22

33
t = 0.0:0.1:1.0
44
f(x) = 2x
@@ -20,7 +20,7 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
2020
@test dx[[:a, :b]] [[f(x), f2(x)] for x in t]
2121
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
2222
@test dx[[:a, :b], 3] [f(t[3]), f2(t[3])]
23-
@test dx[[:a, :b], 4:5] vcat(f.(t[4:5])', f2.(t[4:5])')
23+
@test dx[[:a, :b], 4:5] vcat.(f.(t[4:5]), f2.(t[4:5]))
2424
@test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]]
2525
@test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3]
2626
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
@@ -53,8 +53,3 @@ get_tuple = getu(dx, (:a, :b))
5353

5454
dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b])
5555
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym
56-
57-
ABC = @SLVector (:a, :b, :c);
58-
A = ABC(1, 2, 3);
59-
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
60-
@test getindex(B, :a) == [1, 1]

0 commit comments

Comments
 (0)