Skip to content

Commit 21146df

Browse files
committed
Add Derivative for spline
1 parent 602610b commit 21146df

File tree

7 files changed

+107
-35
lines changed

7 files changed

+107
-35
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
julia 0.7
22
IntervalSets
33
LazyArrays
4+
BandedMatrices

src/ContinuumArrays.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
module ContinuumArrays
2-
using IntervalSets, LinearAlgebra, LazyArrays
2+
using IntervalSets, LinearAlgebra, LazyArrays, BandedMatrices
33
import Base: @_inline_meta, axes, getindex, convert, prod, *, /, \, +, -,
44
IndexStyle, IndexLinear
55
import Base.Broadcast: materialize
6+
import LazyArrays: Mul2
7+
import BandedMatrices: AbstractBandedLayout
68

79
include("QuasiArrays/QuasiArrays.jl")
810
using .QuasiArrays
911
import .QuasiArrays: _length, checkindex, Adjoint, Transpose
1012

11-
export Spline, LinearSpline, HeavisideSpline, DiracDelta
13+
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative
1214

1315
####
1416
# Interval indexing support

src/QuasiArrays/QuasiArrays.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module QuasiArrays
22
using Base, LinearAlgebra, LazyArrays
33
import Base: getindex, size, axes, length, ==, isequal, iterate, CartesianIndices, LinearIndices,
44
Indices, IndexStyle, getindex, setindex!, parent, vec, convert, similar, zero,
5-
map, eachindex, eltype
5+
map, eachindex, eltype, first, last
66
import Base: @_inline_meta, DimOrInd, OneTo, @_propagate_inbounds_meta, @_noinline_meta,
77
DimsInteger, error_if_canonical_getindex, @propagate_inbounds, _return_type, _default_type,
88
_maybetail, tail, _getindex, _maybe_reshape, index_ndims, _unsafe_getindex,
@@ -14,7 +14,7 @@ import Base.Broadcast: materialize
1414

1515
import LinearAlgebra: transpose, adjoint, checkeltype_adjoint, checkeltype_transpose
1616

17-
import LazyArrays: MemoryLayout, UnknownLayout
17+
import LazyArrays: MemoryLayout, UnknownLayout, Mul2
1818

1919
export AbstractQuasiArray, AbstractQuasiMatrix, AbstractQuasiVector, materialize
2020

src/QuasiArrays/matmul.jl

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,44 @@
11

2-
@inline MemoryLayout(A::AbstractQuasiArray{T}) where T = UnknownLayout()
3-
4-
const QuasiArrayMulArray{TV, styleA, styleB, p, q, T, V} =
5-
Mul{TV, styleA, styleB, <:AbstractQuasiArray{T,p}, <:AbstractArray{V,q}}
2+
const QuasiArrayMulArray{styleA, styleB, p, q, T, V} =
3+
Mul2{styleA, styleB, <:AbstractQuasiArray{T,p}, <:AbstractArray{V,q}}
64

5+
const QuasiArrayMulQuasiArray{styleA, styleB, p, q, T, V} =
6+
Mul2{styleA, styleB, <:AbstractQuasiArray{T,p}, <:AbstractQuasiArray{V,q}}
77
####
88
# Matrix * Vector
99
####
10-
let (p,q) = (2,1)
11-
global const QuasiMatMulVec{TV, styleA, styleB, T, V} = QuasiArrayMulArray{TV, styleA, styleB, p, q, T, V}
12-
end
10+
const QuasiMatMulVec{styleA, styleB, T, V} = QuasiArrayMulArray{styleA, styleB, 2, 1, T, V}
1311

14-
axes(M::QuasiMatMulVec) = (axes(M.A,1),)
1512

16-
function getindex(M::QuasiMatMulVec{T}, k::Real) where T
17-
ret = zero(T)
18-
for j = 1:size(M.A,2)
19-
ret += M.A[k,j] * M.B[j]
13+
function getindex(M::QuasiMatMulVec, k::Real)
14+
A,B = M.factors
15+
ret = zero(eltype(M))
16+
@inbounds for j = 1:size(A,2)
17+
ret += A[k,j] * B[j]
2018
end
2119
ret
2220
end
2321

22+
QuasiMatMulMat{styleA, styleB, T, V} = QuasiArrayMulArray{styleA, styleB, 2, 2, T, V}
23+
QuasiMatMulQuasiMat{styleA, styleB, T, V} = QuasiArrayMulQuasiArray{styleA, styleB, 2, 2, T, V}
24+
25+
2426
*(A::AbstractQuasiArray, B::AbstractQuasiArray) = materialize(Mul(A,B))
2527
*(A::AbstractQuasiArray, B::AbstractArray) = materialize(Mul(A,B))
2628
*(A::AbstractArray, B::AbstractQuasiArray) = materialize(Mul(A,B))
2729
inv(A::AbstractQuasiArray) = materialize(Inv(A))
28-
*(A::Inv{<:Any,<:Any,<:AbstractQuasiArray}, B::AbstractQuasiArray) = materialize(Mul(A,B))
30+
*(A::Inv{<:Any,<:AbstractQuasiArray}, B::AbstractQuasiArray) = materialize(Mul(A,B))
31+
32+
33+
_Mul(A::Mul, B::Mul) = Mul(A.factors..., B.factors...)
34+
_Mul(A::Mul, B) = Mul(A.factors..., B)
35+
_Mul(A, B::Mul) = Mul(A, B.factors...)
36+
_Mul(A, B) = Mul(A, B)
37+
_lsimplify2(A, B...) = _Mul(A, _lsimplify(B...))
38+
_lsimplify2(A::Mul, B...) = _lsimplify2(A.factors..., B...)
39+
_lsimplify(A) = materialize(A)
40+
_lsimplify(A, B) = materialize(Mul(A,B))
41+
_lsimplify(A, B, C, D...) = _lsimplify2(materialize(Mul(A,B)), C, D...)
42+
lsimplify(M::Mul) = _lsimplify(M.factors...)
43+
44+
*(A::AbstractQuasiArray, B::Mul) = lsimplify(_Mul(A, B))

src/bases/splines.jl

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,68 @@ function getindex(B::HeavisideSpline{T}, x::Real, k::Int) where T
3535
end
3636

3737

38-
function convert(::Type{SymTridiagonal}, AB::Mul{T,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) where T
39-
Ac,B = AB.A, AB.B
38+
function copyto!(dest::SymTridiagonal, AB::Mul2{<:Any,<:Any,<:Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) where T
39+
Ac,B = AB.factors
4040
A = parent(Ac)
41-
@assert A.points == B.points
41+
A.points == B.points || throw(ArgumentError())
42+
dv,ev = dest.dv,dest.ev
4243
x = A.points; n = length(x)
43-
dv = Vector{T}(undef, n)
44+
length(dv) == n || throw(DimensionMismatch())
45+
4446
dv[1] = (x[2]-x[1])/3
45-
for k = 2:n-1
47+
@inbounds for k = 2:n-1
4648
dv[k] = (x[k+1]-x[k-1])/3
4749
end
4850
dv[n] = (x[n] - x[n-1])/3
4951

50-
SymTridiagonal(dv, diff(x)./6)
52+
@inbounds for k = 1:n-1
53+
ev[k] = (x[k+1]-x[k])/6
54+
end
55+
56+
dest
57+
end
58+
59+
## Mass matrix
60+
function similar(AB::Mul2{<:Any,<:Any,<:Adjoint{<:Any,<:LinearSpline},<:LinearSpline}, ::Type{T}) where T
61+
n = size(AB,1)
62+
SymTridiagonal(Vector{T}(undef, n), Vector{T}(undef, n-1))
5163
end
5264
#
53-
materialize(M::Mul{<:Any,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) =
54-
convert(SymTridiagonal, M)
65+
materialize(M::Mul2{<:Any,<:Any,<:Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) =
66+
copyto!(similar(M, eltype(M)), M)
5567

56-
function materialize(M::Mul{T,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:HeavisideSpline},<:HeavisideSpline}) where T
57-
Ac, B = M.A, M.B
68+
function materialize(M::Mul2{<:Any,<:Any,<:Adjoint{<:Any,<:HeavisideSpline},<:HeavisideSpline})
69+
Ac, B = M.factors
5870
axes(Ac,2) == axes(B,1) || throw(DimensionMismatch("axes must be same"))
5971
A = parent(Ac)
6072
A.points == B.points || throw(ArgumentError("Cannot multiply incompatible splines"))
6173
Diagonal(diff(A.points))
6274
end
75+
76+
## Derivative
77+
function copyto!(dest::Mul2{<:Any,<:Any,<:HeavisideSpline},
78+
M::Mul2{<:Any,<:Any,<:Derivative,<:LinearSpline})
79+
D, L = M.factors
80+
H, A = dest.factors
81+
x = H.points
82+
83+
axes(dest) == axes(M) || throw(DimensionMismatch("axes must be same"))
84+
x == L.points || throw(ArgumentError("Cannot multiply incompatible splines"))
85+
bandwidths(A) == (0,1) || throw(ArgumentError("Not implemented"))
86+
87+
d = diff(x)
88+
A[band(0)] .= (-).(d)
89+
A[band(1)] .= d
90+
91+
dest
92+
end
93+
94+
function similar(M::Mul2{<:Any,<:Any,<:Derivative,<:LinearSpline}, ::Type{T}) where T
95+
D, B = M.factors
96+
n = size(B,2)
97+
Mul(HeavisideSpline{T}(B.points),
98+
BandedMatrix{T}(undef, (n-1,n), (0,1)))
99+
end
100+
101+
materialize(M::Mul2{<:Any,<:Any,<:Derivative,<:LinearSpline}) =
102+
copyto!(similar(M, eltype(M)), M)

src/operators.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,23 @@ function getindex(δ::DiracDelta{T}, x::Real) where T
1717
end
1818

1919

20-
function materialize(M::Mul{<:Any,<:Any,<:Any,<:QuasiArrays.Adjoint{<:Any,<:DiracDelta},<:AbstractQuasiVector})
20+
function materialize(M::Mul2{<:Any,<:Any,<:QuasiArrays.Adjoint{<:Any,<:DiracDelta},<:AbstractQuasiVector})
2121
A, B = M.A, M.B
2222
axes(A,2) == axes(A,1) || throw(DimensionMismatch())
2323
B[parent(A).x]
2424
end
2525

26-
function materialize(M::Mul{<:Any,<:Any,<:Any,<:QuasiArrays.Adjoint{<:Any,<:DiracDelta},<:AbstractQuasiMatrix})
27-
A, B = M.A, M.B
26+
function materialize(M::Mul2{<:Any,<:Any,<:QuasiArrays.Adjoint{<:Any,<:DiracDelta},<:AbstractQuasiMatrix})
27+
A, B = M.factors
2828
axes(A,2) == axes(B,1) || throw(DimensionMismatch())
2929
B[parent(A).x,:]
3030
end
31+
32+
struct Derivative{T,A} <: AbstractQuasiVector{T}
33+
axis::A
34+
end
35+
36+
Derivative{T}(axis::A) where {T,A} = Derivative{T,A}(axis)
37+
Derivative(axis) = Derivative{Float64}(axis)
38+
39+
axes(D::Derivative) = (D.axis, D.axis)

test/runtests.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, Test
1010
@test Base.IndexStyle(δ) === Base.IndexLinear()
1111
end
1212

13-
14-
1513
@testset "HeavisideSpline" begin
1614
H = HeavisideSpline([1,2,3])
1715
@test axes(H) === (axes(H,1),axes(H,2)) === (1.0..3.0, Base.OneTo(2))
@@ -69,9 +67,15 @@ end
6967
L = LinearSpline([1,2,3])
7068
@test δ'L [0.8, 0.2, 0.0]
7169

70+
@test L'L == SymTridiagonal([1/3,2/3,1/3], [1/6,1/6])
7271
end
7372

73+
@testset "Derivative" begin
74+
L = LinearSpline([1,2,3])
75+
f = L*[1,2,4]
76+
D = Derivative(axes(L,1))
77+
fp = D*f
7478

75-
76-
77-
@test L'L == SymTridiagonal([1/3,2/3,1/3], [1/6,1/6])
79+
@test fp[1.1] 1
80+
@test fp[2.2] 2
81+
end

0 commit comments

Comments
 (0)