Skip to content

Commit 652504a

Browse files
Merge pull request #160 from AayushSabharwal/as/rat-stuff
refactor: add DiffEqArray constructor
2 parents 52d8c3f + 2a8477b commit 652504a

File tree

6 files changed

+39
-14
lines changed

6 files changed

+39
-14
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- Core
2020
version:
2121
- '1'
22-
- '1.6'
22+
- '1.10'
2323
steps:
2424
- uses: actions/checkout@v4
2525
- uses: julia-actions/setup-julia@v2

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@ ChainRulesCore = "1"
1919
ForwardDiff = "0.10.3"
2020
MacroTools = "0.5"
2121
PreallocationTools = "0.4"
22-
RecursiveArrayTools = "2,3"
22+
RecursiveArrayTools = "3"
2323
StaticArrays = "1.0"
24-
julia = "1.6"
24+
julia = "1.10"
2525

2626
[extras]
2727
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2828
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
2929
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
30+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
3031
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3132

3233
[targets]

src/LabelledArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import RecursiveArrayTools, PreallocationTools, ForwardDiff
66
include("slarray.jl")
77
include("larray.jl")
88
include("chainrules.jl")
9+
include("diffeqarray.jl")
910

1011
# Common
1112
@generated function __getindex(x::Union{LArray, SLArray}, ::Val{s}) where {s}

src/diffeqarray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
for LArrayType in [LArray, SLArray]
2+
@eval function RecursiveArrayTools.DiffEqArray(vec::AbstractVector{<:$LArrayType},
3+
ts::AbstractVector,
4+
p = nothing)
5+
RecursiveArrayTools.DiffEqArray(vec, ts, p; variables = collect(symbols(vec[1])))
6+
end
7+
end

test/recursivearraytools.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using RecursiveArrayTools, LabelledArrays, Test
2+
3+
ABC = @SLVector (:a, :b, :c);
4+
A = ABC(1, 2, 3);
5+
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
6+
@test getindex(B, :a) == [1, 1]

test/runtests.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,27 @@ using StaticArrays
44
using InteractiveUtils
55
using ChainRulesTestUtils
66

7-
@time begin
8-
@time @testset "SLArrays" begin
9-
include("slarrays.jl")
10-
end
11-
@time @testset "LArrays" begin
12-
include("larrays.jl")
13-
end
14-
@time @testset "DiffEq" begin
15-
include("diffeq.jl")
7+
const GROUP = get(ENV, "GROUP", "All")
8+
9+
if GROUP == "All"
10+
@time begin
11+
@time @testset "SLArrays" begin
12+
include("slarrays.jl")
13+
end
14+
@time @testset "LArrays" begin
15+
include("larrays.jl")
16+
end
17+
@time @testset "DiffEq" begin
18+
include("diffeq.jl")
19+
end
20+
@time @testset "ChainRules" begin
21+
include("chainrules.jl")
22+
end
1623
end
17-
@time @testset "ChainRules" begin
18-
include("chainrules.jl")
24+
end
25+
26+
if GROUP == "All" || GROUP == "RecursiveArrayTools"
27+
@time @testset "RecursiveArrayTools" begin
28+
include("recursivearraytools.jl")
1929
end
2030
end

0 commit comments

Comments
 (0)