Skip to content

Commit 354fec7

Browse files
committed
Add Mul, mass matrices for splines
1 parent 235ff78 commit 354fec7

File tree

11 files changed

+427
-314
lines changed

11 files changed

+427
-314
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
julia 0.7
22
IntervalSets
3+
LazyArrays

examples/BSplines.jl

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using ContinuumArrays, LazyArrays, IntervalSets
2-
import ContinuumArrays: AbstractAxisMatrix, ℵ₀
1+
using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, Test
2+
import ContinuumArrays: AbstractAxisMatrix, ℵ₀, materialize
33
import Base: axes, getindex, convert
44

55
struct Spline{order,T} <: AbstractAxisMatrix{T}
@@ -11,7 +11,7 @@ const HeavisideSpline{T} = Spline{0,T}
1111

1212
Spline{o}(pts::AbstractVector{T}) where {o,T} = Spline{o,float(T)}(pts)
1313

14-
axes(B::Spline{o}) where o = (first(B.points)..last(B.points), Base.OneTo(length(B.points)-o-1))
14+
axes(B::Spline{o}) where o = (first(B.points)..last(B.points), Base.OneTo(length(B.points)+o-1))
1515

1616
function getindex(B::LinearSpline{T}, x::Real, k::Int) where T
1717
x axes(B,1) && 1 k  size(B,2)|| throw(BoundsError())
@@ -37,34 +37,83 @@ function getindex(B::HeavisideSpline{T}, x::Real, k::Int) where T
3737
return one(T)
3838
end
3939

40-
# getindex(B::LinearSpline, ::Colon, k::Int) = Mul(B, [Zeros{Int}(k-1); 1; Zeros{Int}(size(B,2)-k)])
4140

42-
# function convert(::Type{SymTridiagonal}, AB::Mul{<:Any,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:LinearSpline},<:LinearSpline})
43-
# Ac,B = AB.A, AB.B
44-
# A = parent(Ac)
45-
# @assert A.points == B.points
46-
# x = A.points
47-
# SymTridiagonal(x, x/2) # TODO fix
48-
# end
41+
function convert(::Type{SymTridiagonal}, AB::Mul{T,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) where T
42+
Ac,B = AB.A, AB.B
43+
A = parent(Ac)
44+
@assert A.points == B.points
45+
x = A.points; n = length(x)
46+
dv = Vector{T}(undef, n)
47+
dv[1] = (x[2]-x[1])/3
48+
for k = 2:n-1
49+
dv[k] = (x[k+1]-x[k-1])/3
50+
end
51+
dv[n] = (x[n] - x[n-1])/3
52+
53+
SymTridiagonal(dv, diff(x)./6)
54+
end
4955
#
50-
# materialize(M::Mul{<:Any,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) =
51-
# convert(SymTridiagonal, M)
56+
materialize(M::Mul{<:Any,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:LinearSpline},<:LinearSpline}) =
57+
convert(SymTridiagonal, M)
58+
59+
function materialize(M::Mul{T,<:Any,<:Any,<:ContinuumArrays.Adjoint{<:Any,<:HeavisideSpline},<:HeavisideSpline}) where T
60+
Ac, B = M.A, M.B
61+
axes(Ac,2) == axes(B,1) || throw(DimensionMismatch("axes must be same"))
62+
A = parent(Ac)
63+
A.points == B.points || throw(ArgumentError("Cannot multiply incompatible splines"))
64+
Diagonal(diff(A.points))
65+
end
66+
5267

5368
## tests
5469

55-
B = HeavisideSpline([1,2,3])
56-
@test size(B) == (ℵ₀, 2)
70+
H = HeavisideSpline([1,2,3])
71+
@test size(H) == (ℵ₀, 2)
72+
73+
@test_throws BoundsError H[0.1, 1]
74+
@test H[1.1,1] === H'[1,1.1] === transpose(H)[1,1.1] === 1.0
75+
@test H[2.1,1] === H'[1,2.1] === transpose(H)[1,2.1] === 0.0
76+
@test H[1.1,2] === 0.0
77+
@test H[2.1,2] === 1.0
78+
@test_throws BoundsError H[2.1,3]
79+
@test_throws BoundsError H'[3,2.1]
80+
@test_throws BoundsError transpose(H)[3,2.1]
81+
@test_throws BoundsError H[3.1,2]
82+
83+
@test all(H[[1.1,2.1], 1] .=== H'[1,[1.1,2.1]] .=== transpose(H)[1,[1.1,2.1]] .=== [1.0,0.0])
84+
@test all(H[1.1,1:2] .=== [1.0,0.0])
85+
@test all(H[[1.1,2.1], 1:2] .=== [1.0 0.0; 0.0 1.0])
86+
87+
@test_throws BoundsError H[[0.1,2.1], 1]
88+
89+
90+
L = LinearSpline([1,2,3])
91+
@test size(L) == (ℵ₀, 3)
92+
93+
@test_throws BoundsError L[0.1, 1]
94+
@test L[1.1,1] == L'[1,1.1] == transpose(L)[1,1.1] 0.9
95+
@test L[2.1,1] === L'[1,2.1] === transpose(L)[1,2.1] === 0.0
96+
@test L[1.1,2] 0.1
97+
@test L[2.1,2] 0.9
98+
@test L[2.1,3] == L'[3,2.1] == transpose(L)[3,2.1] 0.1
99+
@test_throws BoundsError L[3.1,2]
100+
L[[1.1,2.1], 1]
101+
@test L[[1.1,2.1], 1] == L'[1,[1.1,2.1]] == transpose(L)[1,[1.1,2.1]] [0.9,0.0]
102+
@test L[1.1,1:2] [0.9,0.1]
103+
@test L[[1.1,2.1], 1:2] [0.9 0.1; 0.0 0.9]
104+
105+
@test_throws BoundsError L[[0.1,2.1], 1]
106+
57107

58-
@test_throws BoundsError B[0.1, 1]
59-
@test B[1.1,1] === 1.0
60-
@test B[2.1,1] === 0.0
61-
@test B[1.1,2] === 0.0
62-
@test B[2.1,2] === 1.0
63-
@test_throws BoundsError B[2.1,3]
64-
@test_throws BoundsError B[3.1,2]
108+
f = H*[1,2]
109+
@test axes(f) == (1.0..3.0,)
110+
@test f[1.1] 1
111+
@test f[2.1] 2
65112

66-
@test all(B[[1.1,2.1], 1] .=== [1.0,0.0])
67-
@test all(B[1.1,1:2] .=== [1.0,0.0])
68-
@test all(B[[1.1,2.1], 1:2] .=== [1.0 0.0; 0.0 1.0])
113+
f = L*[1,2,4]
114+
@test axes(f) == (1.0..3.0,)
115+
@test f[1.1] 1.1
116+
@test f[2.1] 2.2
69117

70-
@test_throws BoundsError B[[0.1,2.1], 1]
118+
@test H'H == Eye(2)
119+
@test L'L == SymTridiagonal([1/3,2/3,1/3], [1/6,1/6])

src/ContinuumArrays.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
module ContinuumArrays
2-
using Base, LinearAlgebra, IntervalSets
3-
import Base: getindex, size, axes, length, ==, isequal, iterate, CartesianIndices, LinearIndices,
4-
Indices, IndexStyle, getindex, setindex!, parent, vec, convert, similar, zero,
5-
map, eachindex
6-
import Base: @_inline_meta, DimOrInd, OneTo, @_propagate_inbounds_meta, @_noinline_meta,
7-
DimsInteger, error_if_canonical_getindex, @propagate_inbounds, _return_type, _default_type,
8-
_maybetail, tail, _getindex, _maybe_reshape, index_ndims, _unsafe_getindex,
9-
index_shape, to_shape, unsafe_length, @nloops, @ncall
2+
include("axisarrays/AbstractAxisArrays.jl")
3+
using .AbstractAxisArrays
104

115

12-
import LinearAlgebra: transpose, adjoint
6+
####
7+
# Interval indexing support
8+
####
139

14-
abstract type AbstractAxisArray{T,N} end
15-
AbstractAxisVector{T} = AbstractAxisArray{T,1}
16-
AbstractAxisMatrix{T} = AbstractAxisArray{T,2}
17-
AbstractAxisVecOrMat{T} = Union{AbstractAxisVector{T}, AbstractAxisMatrix{T}}
10+
using IntervalSets
11+
import .AbstractAxisArrays: _length, checkindex, Adjoint, Transpose
12+
import Base: @_inline_meta
1813

1914
struct ℵ₀ <: Number end
2015
_length(::AbstractInterval) = ℵ₀
21-
_length(d) = length(d)
2216

23-
size(A::AbstractAxisArray) = _length.(axes(A))
24-
axes(A::AbstractAxisArray) = error("Override axes for $(typeof(A))")
17+
checkindex(::Type{Bool}, inds::AbstractInterval, i::Real) = (leftendpoint(inds) <= i) & (i <= rightendpoint(inds))
18+
function checkindex(::Type{Bool}, inds::AbstractInterval, I::AbstractArray)
19+
@_inline_meta
20+
b = true
21+
for i in I
22+
b &= checkindex(Bool, inds, i)
23+
end
24+
b
25+
end
2526

26-
include("indices.jl")
27-
include("abstractaxisarray.jl")
28-
include("adjtrans.jl")
29-
include("multidimensional.jl")
3027
end

0 commit comments

Comments
 (0)