Skip to content

Commit 1a3f104

Browse files
Merge pull request #123 from SciML/preallocationtools_dep
Pick up a PreallocationTools.jl dependency
2 parents d2ca535 + c70d08f commit 1a3f104

File tree

4 files changed

+28
-1
lines changed

4 files changed

+28
-1
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1212
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
14+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1315

1416
[compat]
1517
ArrayInterface = "2.8, 3.0, 4, 5, 6"
1618
ArrayInterfaceStaticArrays = "0.1"
1719
ChainRulesCore = "1"
1820
MacroTools = "0.5"
19-
StaticArrays = "0.10, 0.11, 0.12, 1.0"
21+
PreallocationTools = "0.4"
22+
RecursiveArrayTools = "2"
23+
StaticArrays = "1.0"
2024
julia = "1.6"
2125

2226
[extras]

src/LabelledArrays.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LabelledArrays
22

33
using LinearAlgebra, StaticArrays, ArrayInterface
4+
import RecursiveArrayTools, PreallocationTools
45

56
include("slarray.jl")
67
include("larray.jl")
@@ -63,6 +64,16 @@ function ArrayInterface.ismutable(::Type{<:LArray{T, N, Syms}}) where {T, N, Sym
6364
end
6465
ArrayInterface.can_setindex(::Type{<:SLArray}) = false
6566

67+
function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache,
68+
u::LArray{T, N, D, Syms}) where {T, N, D, Syms}
69+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
70+
if nelem > length(dc.dual_du)
71+
PreallocationTools.enlargedualcache!(dc, nelem)
72+
end
73+
_x = ArrayInterfaceCore.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
74+
LabelledArrays.LArray{T, N, D, Syms}(_x)
75+
end
76+
6677
export SLArray, LArray, SLVector, LVector, @SLVector, @LArray, @LVector, @SLArray
6778

6879
export @SLSliced, @LSliced

src/larray.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,12 @@ function Base.vcat(x::LArray, y::LArray)
366366
end
367367

368368
Base.elsize(::Type{<:LArray{T}}) where {T} = sizeof(T)
369+
370+
function RecursiveArrayTools.recursive_unitless_eltype(a::Type{LArray{T, N, D, Syms}}) where {
371+
T,
372+
N,
373+
D,
374+
Syms
375+
}
376+
LArray{typeof(one(T)), N, D, Syms}
377+
end

src/slarray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ function StaticArrays.similar_type(::Type{SLArray{S, T, N, L, Syms}}, T2,
3939
::Size{S}) where {S, T, N, L, Syms}
4040
SLArray{S, T2, N, L, Syms}
4141
end
42+
function RecursiveArrayTools.recursive_unitless_eltype(a::Type{T}) where {T <: SLArray}
43+
StaticArrays.similar_type(a, recursive_unitless_eltype(eltype(a)))
44+
end
4245

4346
## Named tuple to SLArray
4447
#=

0 commit comments

Comments
 (0)