Skip to content

Commit efff2cd

Browse files
authored
Use extensions introduced in Julia 1.9 (#192)
* Update chainrulescore.jl * Update Project.toml * fix scalar indexing using ForwardDiff * a better way * GPUComponentArray * test * fix * jacobian * Use extension * Update RecursiveArrayToolsExt.jl * Update gpu_tests.jl * Update ComponentArrays.jl * Update ComponentArrays.jl * Update Project.toml * Update GPUArraysExt.jl * Update Project.toml * Update Project.toml * remove forwarddiffext * Update ComponentArrays.jl
1 parent a202b30 commit efff2cd

10 files changed

+87
-17
lines changed

Project.toml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,32 @@ version = "0.13.9"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
10+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
912
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1014
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
15+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
16+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1117
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
18+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
19+
20+
[weakdeps]
21+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
22+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
23+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
24+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
25+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
26+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
27+
28+
[extensions]
29+
ConstructionBaseExt = "ConstructionBase"
30+
GPUArraysExt = "GPUArrays"
31+
RecursiveArrayToolsExt = "RecursiveArrayTools"
32+
ReverseDiffExt = "ReverseDiff"
33+
SciMLBaseExt = "SciMLBase"
34+
StaticArraysExt = "StaticArrays"
1235

1336
[compat]
1437
ArrayInterface = "6, 7"
@@ -18,7 +41,13 @@ StaticArrayInterface = "1"
1841
julia = "1.6"
1942

2043
[extras]
44+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
2145
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
46+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
47+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
48+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
49+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
50+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2251
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2352

2453
[targets]

ext/ConstructionBaseExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module ConstructionBaseExt
2+
3+
using ComponentArrays
4+
isdefined(Base, :get_extension) ? (using ConstructionBase) : (using ..ConstructionBase)
5+
6+
ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...)
7+
8+
end

src/compat/gpuarrays.jl renamed to ext/GPUArraysExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
module GPUArraysExt
2+
3+
using ComponentArrays, LinearAlgebra
4+
isdefined(Base, :get_extension) ? (using GPUArrays) : (using ..GPUArrays)
5+
16
const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
27
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax}
38
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
@@ -271,3 +276,5 @@ function LinearAlgebra.mul!(C::GPUComponentVecorMat,
271276
}, a::Real, b::Real)
272277
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
273278
end
279+
280+
end
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
module RecursiveArrayToolsExt
2+
3+
using ComponentArrays
4+
isdefined(Base, :get_extension) ? (using RecursiveArrayTools) : (using ..RecursiveArrayTools)
5+
16
AVOA = RecursiveArrayTools.AbstractVectorOfArray
27

38
function Base.Array(VA::AVOA{T,N,A}) where {T,N,A<:AbstractVector{<:ComponentVector}}
49
return ComponentArray(reduce(hcat, VA.u), only(getaxes(VA.u[1])), FlatAxis())
5-
end
10+
end
11+
12+
end

src/compat/reversediff.jl renamed to ext/ReverseDiffExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
module ReverseDiffExt
2+
3+
using ComponentArrays
4+
isdefined(Base, :get_extension) ? (using ReverseDiff) : (using ..ReverseDiff)
5+
16
const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}
27

38
maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape)
@@ -25,4 +30,6 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol)
2530
t = ReverseDiff.tape(tca)
2631
return maybe_tracked_array(val, der, t, (s,), tca)
2732
end
28-
end
33+
end
34+
35+
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
# Plotting stuff
2+
module SciMLBaseExt
3+
4+
using ComponentArrays
5+
isdefined(Base, :get_extension) ? (using SciMLBase) : (using ..SciMLBase)
6+
27
function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{T,N,C}) where {T,N,C<:AbstractVector{<:ComponentArray}}
38
if SciMLBase.has_syms(sol.prob.f)
49
return sol.prob.f.syms
510
else
611
return Symbol.(labels(sol.u[1]))
712
end
813
end
14+
15+
end
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} =
2-
ComponentArray(similar(A), ax...)
1+
module StaticArraysExt
32

3+
using ComponentArrays
4+
isdefined(Base, :get_extension) ? (using StaticArrays) : (using ..StaticArrays)
45

6+
ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} =
7+
ComponentArray(similar(A), ax...)
58

9+
end

src/ComponentArrays.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import ChainRulesCore
44
import StaticArrayInterface, ArrayInterface
55

66
using LinearAlgebra
7-
using Requires
7+
8+
if !isdefined(Base, :get_extension)
9+
using Requires
10+
end
811

912
const FlatIdx = Union{Integer, CartesianIndex, CartesianIndices, AbstractArray{<:Integer}}
1013
const FlatOrColonIdx = Union{FlatIdx, Colon}
@@ -49,16 +52,15 @@ export labels, label2index
4952

5053
include("compat/chainrulescore.jl")
5154

52-
53-
required(filename) = include(joinpath("compat", filename))
54-
5555
function __init__()
56-
@require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" required("constructionbase.jl")
57-
@require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" required("scimlbase.jl")
58-
@require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl")
59-
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl")
60-
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl")
61-
@require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl")
56+
@static if !isdefined(Base, :get_extension)
57+
@require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" include("../ext/ConstructionBaseExt.jl")
58+
@require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" include("../ext/SciMLBaseExt.jl")
59+
@require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" include("../ext/RecursiveArrayToolsExt.jl")
60+
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" include("../ext/StaticArraysExt.jl")
61+
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl")
62+
@require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" include("../ext/GPUArraysExt.jl")
63+
end
6264
end
6365

64-
end
66+
end

src/compat/constructionbase.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/similar_convert_copy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,4 @@ Base.NamedTuple(x::ComponentVector) = _namedtuple(x)
7575

7676

7777
## AbstractAxis conversion and promotion
78-
Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax
78+
Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax

0 commit comments

Comments
 (0)