Skip to content

Commit 7ac307d

Browse files
authored
Support SubBasisLayout (#30)
* start sub * tests pass * Increase coverage * MappedBasisLayout
1 parent e67d37f commit 7ac307d

File tree

3 files changed

+151
-60
lines changed

3 files changed

+151
-60
lines changed

src/ContinuumArrays.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ module ContinuumArrays
22
using IntervalSets, LinearAlgebra, LazyArrays, FillArrays, BandedMatrices, QuasiArrays
33
import Base: @_inline_meta, @_propagate_inbounds_meta, axes, getindex, convert, prod, *, /, \, +, -, ==,
44
IndexStyle, IndexLinear, ==, OneTo, tail, similar, copyto!, copy,
5-
first, last, show, isempty, findfirst, findlast, findall, Slice
5+
first, last, show, isempty, findfirst, findlast, findall, Slice, union, minimum, maximum
66
import Base.Broadcast: materialize, BroadcastStyle, broadcasted
7-
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport,
8-
adjointlayout, LdivApplyStyle, arguments, broadcastlayout, lazy_getindex,
7+
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport,
8+
adjointlayout, LdivApplyStyle, arguments, _arguments, call, broadcastlayout, lazy_getindex,
99
sublayout, ApplyLayout, BroadcastLayout, combine_mul_styles
1010
import LinearAlgebra: pinv
1111
import BandedMatrices: AbstractBandedLayout, _BandedMatrix
@@ -16,7 +16,7 @@ import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclu
1616
ApplyQuasiArray, ApplyQuasiMatrix, LazyQuasiArrayApplyStyle, AbstractQuasiArrayApplyStyle,
1717
LazyQuasiArray, LazyQuasiVector, LazyQuasiMatrix, LazyLayout, LazyQuasiArrayStyle
1818

19-
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, fullmaterialize, ℵ₁, Inclusion, Basis, WeightedBasis, grid
19+
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, fullmaterialize, ℵ₁, Inclusion, Basis, WeightedBasis, grid, transform
2020

2121
####
2222
# Interval indexing support
@@ -48,7 +48,7 @@ for find in (:findfirst, :findlast)
4848
@eval $find(f::Base.Fix2{typeof(isequal)}, d::Inclusion) = f.x in d.domain ? f.x : nothing
4949
end
5050

51-
function findall(f::Base.Fix2{typeof(isequal)}, d::Inclusion)
51+
function findall(f::Base.Fix2{typeof(isequal)}, d::Inclusion)
5252
r = findfirst(f,d)
5353
r === nothing ? eltype(d)[] : [r]
5454
end
@@ -79,13 +79,13 @@ end
7979
AffineQuasiVector(A::AA, x::X, b::B) where {AA,X,B} =
8080
AffineQuasiVector{promote_type(eltype(AA), eltype(X), eltype(B)),AA,X,B}(A,x,b)
8181

82-
AffineQuasiVector(A, x) = AffineQuasiVector(A, x, zero(promote_type(eltype(A),eltype(x))))
82+
AffineQuasiVector(A, x) = AffineQuasiVector(A, x, zero(promote_type(eltype(A),eltype(x))))
8383
AffineQuasiVector(x) = AffineQuasiVector(one(eltype(x)), x)
8484

8585
AffineQuasiVector(A, x::AffineQuasiVector, b) = AffineQuasiVector(A*x.A, x.x, A*x.b .+ b)
8686

8787
axes(A::AffineQuasiVector) = axes(A.x)
88-
getindex(A::AffineQuasiVector, k::Number) = A.A*A.x[k] .+ A.b
88+
getindex(A::AffineQuasiVector, k::Number) = A.A*A.x[k] .+ A.b
8989
inbounds_getindex(A::AffineQuasiVector{<:Any,<:Any,<:Inclusion}, k::Number) = A.A*k .+ A.b
9090
isempty(A::AffineQuasiVector) = isempty(A.x)
9191
==(a::AffineQuasiVector, b::AffineQuasiVector) = a.A == b.A && a.x == b.x && a.b == b.b
@@ -116,6 +116,11 @@ for find in (:findfirst, :findlast, :findall)
116116
@eval $find(f::Base.Fix2{typeof(isequal)}, d::AffineQuasiVector) = $find(isequal(d.A \ (f.x .- d.b)), d.x)
117117
end
118118

119+
minimum(d::AffineQuasiVector{<:Real,<:Real,<:Inclusion}) = signbit(d.A) ? last(d) : first(d)
120+
maximum(d::AffineQuasiVector{<:Real,<:Real,<:Inclusion}) = signbit(d.A) ? first(d) : last(d)
121+
122+
union(d::AffineQuasiVector{<:Real,<:Real,<:Inclusion}) = Inclusion(minimum(d)..maximum(d))
123+
119124

120125

121126

src/bases/bases.jl

Lines changed: 79 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,27 @@ abstract type Weight{T} <: LazyQuasiVector{T} end
55
const WeightedBasis{T, A<:AbstractQuasiVector, B<:Basis} = BroadcastQuasiMatrix{T,typeof(*),<:Tuple{A,B}}
66

77
struct WeightLayout <: MemoryLayout end
8-
struct BasisLayout <: MemoryLayout end
9-
struct AdjointBasisLayout <: MemoryLayout end
8+
abstract type AbstractBasisLayout <: MemoryLayout end
9+
struct BasisLayout <: AbstractBasisLayout end
10+
struct SubBasisLayout <: AbstractBasisLayout end
11+
struct MappedBasisLayout <: AbstractBasisLayout end
12+
13+
abstract type AbstractAdjointBasisLayout <: MemoryLayout end
14+
struct AdjointBasisLayout <: AbstractAdjointBasisLayout end
15+
struct AdjointSubBasisLayout <: AbstractAdjointBasisLayout end
16+
struct AdjointMappedBasisLayout <: AbstractAdjointBasisLayout end
1017

1118
MemoryLayout(::Type{<:Basis}) = BasisLayout()
1219
MemoryLayout(::Type{<:Weight}) = WeightLayout()
1320

1421
adjointlayout(::Type, ::BasisLayout) = AdjointBasisLayout()
15-
transposelayout(::Type{<:Real}, ::BasisLayout) = AdjointBasisLayout()
22+
adjointlayout(::Type, ::SubBasisLayout) = AdjointSubBasisLayout()
23+
adjointlayout(::Type, ::MappedBasisLayout) = AdjointMappedBasisLayout()
1624
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::BasisLayout) = BasisLayout()
25+
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::SubBasisLayout) = SubBasisLayout()
1726

18-
combine_mul_styles(::BasisLayout) = LazyQuasiArrayApplyStyle()
19-
combine_mul_styles(::AdjointBasisLayout) = LazyQuasiArrayApplyStyle()
27+
combine_mul_styles(::AbstractBasisLayout) = LazyQuasiArrayApplyStyle()
28+
combine_mul_styles(::AbstractAdjointBasisLayout) = LazyQuasiArrayApplyStyle()
2029

2130
ApplyStyle(::typeof(pinv), ::Type{<:Basis}) = LazyQuasiArrayApplyStyle()
2231
pinv(J::Basis) = apply(pinv,J)
@@ -25,7 +34,7 @@ _multup(a::Tuple) = Mul(a...)
2534
_multup(a) = a
2635

2736

28-
function ==(A::Basis, B::Basis)
37+
function ==(A::Basis, B::Basis)
2938
axes(A) == axes(B) && throw(ArgumentError("Override == to compare bases of type $(typeof(A)) and $(typeof(B))"))
3039
false
3140
end
@@ -36,42 +45,45 @@ ApplyStyle(::typeof(\), ::Type{<:Basis}, ::Type{<:AbstractQuasiVector}) = LdivAp
3645
ApplyStyle(::typeof(\), ::Type{<:SubQuasiArray{<:Any,2,<:Basis}}, ::Type{<:AbstractQuasiMatrix}) = LdivApplyStyle()
3746
ApplyStyle(::typeof(\), ::Type{<:SubQuasiArray{<:Any,2,<:Basis}}, ::Type{<:AbstractQuasiVector}) = LdivApplyStyle()
3847

39-
copy(L::Ldiv{BasisLayout,BroadcastLayout{typeof(+)}}) = +(broadcast(\,Ref(L.A),arguments(L.B))...)
40-
function copy(L::Ldiv{BasisLayout,BroadcastLayout{typeof(-)}})
48+
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)}}) = +(broadcast(\,Ref(L.A),arguments(L.B))...)
49+
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)},<:Any,<:AbstractQuasiVector}) =
50+
+(broadcast(\,Ref(L.A),arguments(L.B))...)
51+
52+
function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)}})
53+
a,b = arguments(L.B)
54+
(L.A\a)-(L.A\b)
55+
end
56+
57+
function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)},<:Any,<:AbstractQuasiVector})
4158
a,b = arguments(L.B)
4259
(L.A\a)-(L.A\b)
4360
end
4461

62+
function copy(P::Ldiv{BasisLayout,BasisLayout})
63+
A, B = P.A, P.B
64+
A == B || throw(ArgumentError("Override materialize for $(typeof(A)) \\ $(typeof(B))"))
65+
Eye(size(A,2))
66+
end
67+
function copy(P::Ldiv{SubBasisLayout,SubBasisLayout})
68+
A, B = P.A, P.B
69+
(parent(A) == parent(B) && parentindices(A) == parentindices(B)) ||
70+
throw(ArgumentError("Override materialize for $(typeof(A)) \\ $(typeof(B))"))
71+
Eye(size(A,2))
72+
end
73+
74+
function copy(P::Ldiv{MappedBasisLayout,MappedBasisLayout})
75+
A, B = P.A, P.B
76+
demap(A)\demap(B)
77+
end
78+
# function copy(P::Ldiv{MappedBasisLayout,SubBasisLayout})
79+
# A, B = P.A, P.B
80+
# # use lazy_getindex to avoid sparse arrays
81+
# lazy_getindex(parent(A)\parent(B),:,parentindices(B)[2])
82+
# end
83+
4584
for Bas1 in (:Basis, :WeightedBasis), Bas2 in (:Basis, :WeightedBasis)
46-
@eval begin
47-
function copy(P::Ldiv{<:Any,<:Any,<:$Bas1,<:$Bas2})
48-
A, B = P.A, P.B
49-
A == B || throw(ArgumentError("Override materialize for $(typeof(A)) \\ $(typeof(B))"))
50-
Eye(size(A,2))
51-
end
52-
function copy(P::Ldiv{<:Any,<:Any,<:SubQuasiArray{<:Any,2,<:$Bas1},<:SubQuasiArray{<:Any,2,<:$Bas2}})
53-
A, B = P.A, P.B
54-
(parent(A) == parent(B) && parentindices(A) == parentindices(B)) ||
55-
throw(ArgumentError("Override materialize for $(typeof(A)) \\ $(typeof(B))"))
56-
Eye(size(A,2))
57-
end
58-
59-
function copy(P::Ldiv{<:Any,<:Any,<:SubQuasiArray{<:Any,2,<:$Bas1,<:Tuple{<:AffineQuasiVector,<:Slice}},
60-
<:SubQuasiArray{<:Any,2,<:$Bas2,<:Tuple{<:AffineQuasiVector,<:Slice}}})
61-
A, B = P.A, P.B
62-
parent(A)\parent(B)
63-
end
64-
function copy(P::Ldiv{<:Any,<:Any,<:SubQuasiArray{<:Any,2,<:$Bas1,<:Tuple{<:AffineQuasiVector,<:Slice}},
65-
<:SubQuasiArray{<:Any,2,<:$Bas2,<:Tuple{<:AffineQuasiVector,<:Any}}})
66-
A, B = P.A, P.B
67-
# use lazy_getindex to avoid sparse arrays
68-
lazy_getindex(parent(A)\parent(B),:,parentindices(B)[2])
69-
end
70-
71-
function ==(A::SubQuasiArray{<:Any,2,<:$Bas1}, B::SubQuasiArray{<:Any,2,<:$Bas2})
72-
all(parentindices(A) == parentindices(B)) && parent(A) == parent(B)
73-
end
74-
end
85+
@eval ==(A::SubQuasiArray{<:Any,2,<:$Bas1}, B::SubQuasiArray{<:Any,2,<:$Bas2}) =
86+
all(parentindices(A) == parentindices(B)) && parent(A) == parent(B)
7587
end
7688

7789

@@ -82,15 +94,15 @@ function transform(L)
8294
p,L[p,:]
8395
end
8496

85-
function copy(L::Ldiv{BasisLayout,<:Any,<:Any,<:AbstractQuasiVector})
97+
function copy(L::Ldiv{<:AbstractBasisLayout,<:Any,<:Any,<:AbstractQuasiVector})
8698
p,T = transform(L.A)
8799
T \ L.B[p]
88100
end
89101

90-
copy(L::Ldiv{BasisLayout,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) =
102+
copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) =
91103
copy(Ldiv{LazyLayout,ApplyLayout{typeof(*)}}(L.A, L.B))
92104

93-
function copy(L::Ldiv{BasisLayout,BroadcastLayout{typeof(*)},<:AbstractQuasiMatrix,<:AbstractQuasiVector})
105+
function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:AbstractQuasiMatrix,<:AbstractQuasiVector})
94106
p,T = transform(L.A)
95107
T \ L.B[p]
96108
end
@@ -101,7 +113,7 @@ end
101113
# *(arguments(S)...)
102114

103115

104-
# Differentiation of sub-arrays
116+
# Differentiation of sub-arrays
105117
function copy(M::QMul2{<:Derivative,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}}})
106118
A, B = M.args
107119
P = parent(B)
@@ -115,7 +127,7 @@ function copy(M::QMul2{<:Derivative,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatri
115127
(Derivative(axes(P,1))*P*kr.A)[kr,jr]
116128
end
117129

118-
function copy(L::Ldiv{BasisLayout,BroadcastLayout{typeof(*)},<:AbstractQuasiMatrix})
130+
function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:AbstractQuasiMatrix})
119131
args = arguments(L.B)
120132
# this is a temporary hack
121133
if args isa Tuple{AbstractQuasiMatrix,Number}
@@ -129,8 +141,29 @@ end
129141

130142

131143
# we represent as a Mul with a banded matrix
132-
sublayout(::BasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractUnitRange}}) = ApplyLayout{typeof(*)}()
133-
sublayout(::BasisLayout, ::Type{<:Tuple{<:AffineQuasiVector,<:AbstractUnitRange}}) = BasisLayout()
144+
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractUnitRange}}) = SubBasisLayout()
145+
sublayout(::BasisLayout, ::Type{<:Tuple{<:AffineQuasiVector,<:AbstractUnitRange}}) = MappedBasisLayout()
146+
147+
demap(x) = x
148+
demap(V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Any,<:Slice}}) = parent(V)
149+
function demap(V::SubQuasiArray{<:Any,2})
150+
kr, jr = parentindices(V)
151+
demap(parent(V)[kr,:])[:,jr]
152+
end
153+
154+
155+
##
156+
# SubLayout behaves like ApplyLayout{typeof(*)}
157+
158+
combine_mul_styles(::SubBasisLayout) = combine_mul_styles(ApplyLayout{typeof(*)}())
159+
_arguments(::SubBasisLayout, A) = _arguments(ApplyLayout{typeof(*)}(), A)
160+
call(::SubBasisLayout, ::SubQuasiArray) = *
161+
162+
combine_mul_styles(::AdjointSubBasisLayout) = combine_mul_styles(ApplyLayout{typeof(*)}())
163+
_arguments(::AdjointSubBasisLayout, A) = _arguments(ApplyLayout{typeof(*)}(), A)
164+
arguments(::AdjointSubBasisLayout, A) = arguments(ApplyLayout{typeof(*)}(), A)
165+
call(::AdjointSubBasisLayout, ::SubQuasiArray) = *
166+
134167
function arguments(V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Inclusion,<:AbstractUnitRange}})
135168
A = parent(V)
136169
_,jr = parentindices(V)
@@ -139,8 +172,8 @@ function arguments(V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Inclusion,<:Abstract
139172
A,P
140173
end
141174

142-
143-
include("splines.jl")
175+
copy(L::Ldiv{BasisLayout,SubBasisLayout}) = apply(\, L.A, ApplyQuasiArray(L.B))
144176

145177

146178

179+
include("splines.jl")

test/runtests.jl

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ContinuumArrays, QuasiArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, ForwardDiff, Test
2-
import ContinuumArrays: ℵ₁, materialize, SimplifyStyle, AffineQuasiVector, BasisLayout, AdjointBasisLayout
2+
import ContinuumArrays: ℵ₁, materialize, SimplifyStyle, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, MappedBasisLayout
33
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
44
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout
55
import ForwardDiff: Dual
@@ -70,7 +70,7 @@ end
7070
@test L[2.1,2] 0.9
7171
@test L[2.1,3] == L'[3,2.1] == transpose(L)[3,2.1] 0.1
7272
@test_throws BoundsError L[3.1,2]
73-
73+
7474
@test L[[1.1,2.1], 1] == L'[1,[1.1,2.1]] == transpose(L)[1,[1.1,2.1]] [0.9,0.0]
7575
@test L[1.1,1:2] [0.9,0.1]
7676
@test L[[1.1,2.1], 1:2] [0.9 0.1; 0.0 0.9]
@@ -87,6 +87,34 @@ end
8787
@test δ'L [0.8, 0.2, 0.0]
8888

8989
@test L'L == SymTridiagonal([1/3,2/3,1/3], [1/6,1/6])
90+
91+
@testset "==" begin
92+
L = LinearSpline([1,2,3])
93+
H = HeavisideSpline([1,2,3])
94+
@test L == L
95+
@test L H
96+
H = HeavisideSpline([1,1.5,2.5,3])
97+
@test_throws ArgumentError L == H
98+
end
99+
100+
@testset "Adjoint layout" begin
101+
L = LinearSpline([1,2,3])
102+
@test MemoryLayout(typeof(L')) == AdjointBasisLayout()
103+
@test [3,4,5]'*L' isa ApplyQuasiArray
104+
end
105+
106+
@testset "Broadcast layout" begin
107+
L = LinearSpline([1,2,3])
108+
b = BroadcastQuasiArray(+, L*[3,4,5], L*[1.,2,3])
109+
@test (L\b) == [4,6,8]
110+
B = BroadcastQuasiArray(+, L, L)
111+
@test L\B == 2Eye(3)
112+
113+
b = BroadcastQuasiArray(-, L*[3,4,5], L*[1.,2,3])
114+
@test (L\b) == [2,2,2]
115+
B = BroadcastQuasiArray(-, L, L)
116+
@test L\B == 0Eye(3)
117+
end
90118
end
91119

92120
@testset "Derivative" begin
@@ -168,7 +196,7 @@ end
168196
@test_throws BoundsError L[0.1,1]
169197
@test_throws BoundsError L[1.1,0]
170198

171-
@test MemoryLayout(typeof(L[:,2:3])) isa ApplyLayout{typeof(*)}
199+
@test MemoryLayout(typeof(L[:,2:3])) isa SubBasisLayout
172200
@test L\L[:,2:3] isa BandedMatrix
173201
@test L\L[:,2:3] == [0 0; 1 0; 0 1.0; 0 0]
174202

@@ -222,8 +250,8 @@ end
222250
@test u[0.1] 0.00012678835289369413
223251
end
224252

225-
@testset "Change-of-variables" begin
226-
x = Inclusion(0..1)
253+
@testset "AffineQuasiVector" begin
254+
x = Inclusion(0..1)
227255
@test 2x isa AffineQuasiVector
228256
@test (2x)[0.1] == 0.2
229257
@test_throws BoundsError (2x)[2]
@@ -248,14 +276,33 @@ end
248276
@test findall(isequal(0.2),y) == [0.6]
249277
@test findall(isequal(2),y) == Float64[]
250278

279+
@test AffineQuasiVector(x)[0.1] == 0.1
280+
@test minimum(y) == -1
281+
@test maximum(y) == 1
282+
@test union(y) == Inclusion(-1..1)
283+
@test ContinuumArrays.inbounds_getindex(y,0.1) == y[0.1]
284+
@test ContinuumArrays.inbounds_getindex(y,2.1) == 2*2.1 - 1
285+
286+
z = 1 .- x
287+
@test minimum(z) == 0.0
288+
@test maximum(z) == 1.0
289+
@test union(z) == Inclusion(0..1)
290+
291+
@test !isempty(z)
292+
@test z == z
293+
end
294+
295+
@testset "Change-of-variables" begin
296+
x = Inclusion(0..1)
297+
y = 2x .- 1
251298
L = LinearSpline(range(-1,stop=1,length=10))
252299
@test L[y,:][0.1,:] == L[2*0.1-1,:]
253300

254301
D = Derivative(axes(L,1))
255302
H = HeavisideSpline(L.points)
256303
@test H\((D*L) * 2) (H\(D*L))*2 diagm(0 => fill(-9,9), 1 => fill(9,9))[1:end-1,:]
257304

258-
@test MemoryLayout(typeof(L[y,:])) isa BasisLayout
305+
@test MemoryLayout(typeof(L[y,:])) isa MappedBasisLayout
259306
a,b = arguments((D*L)[y,:])
260307
@test H[y,:]\a == Eye(9)
261308
@test H[y,:] \ (D*L)[y,:] isa BandedMatrix
@@ -264,4 +311,10 @@ end
264311
@test (D*L[y,:])[0.1,1] -9
265312
@test H[y,:] \ (D*L[y,:]) isa BandedMatrix
266313
@test H[y,:] \ (D*L[y,:]) diagm(0 => fill(-9,9), 1 => fill(9,9))[1:end-1,:]
267-
end
314+
315+
B = L[y,2:end-1]
316+
@test MemoryLayout(typeof(B)) isa MappedBasisLayout
317+
@test B[0.1,1] == L[2*0.1-1,2]
318+
@test B\B == Eye(8)
319+
@test L[y,:] \ B == Eye(10)[:,2:end-1]
320+
end

0 commit comments

Comments
 (0)