Skip to content

Commit e289553

Browse files
authored
Simplify Ldiv (#60)
* Update bases.jl * WeightedBasis interface * sublayout for Weight * more bases * ADd SubWeightedBasis * Increase coverage * v0.3.2 * MappedWeightedBasisLayout * Update bases.jl * Update bases.jl * Move out BroadcastStyle overload * Update Project.toml * Increase coverage
1 parent 5bdd8fb commit e289553

File tree

5 files changed

+485
-399
lines changed

5 files changed

+485
-399
lines changed

Project.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.3.1"
3+
version = "0.3.2"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -13,18 +13,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
QuasiArrays = "c4ea9172-b204-11e9-377d-29865faadc5c"
1414

1515
[compat]
16-
ArrayLayouts = "0.4.3"
16+
ArrayLayouts = "0.4.7"
1717
BandedMatrices = "0.15.17"
1818
FillArrays = "0.9.3"
1919
InfiniteArrays = "0.8"
20-
IntervalSets = "0.3.2, 0.4, 0.5"
21-
LazyArrays = "0.17.1"
22-
QuasiArrays = "0.3.1"
20+
IntervalSets = "0.4, 0.5"
21+
LazyArrays = "0.18"
22+
QuasiArrays = "0.3.4"
2323
julia = "1.5"
2424

2525
[extras]
26+
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
2627
FastTransforms = "057dd010-8810-581a-b7be-e3fc3b93f78c"
2728
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2829

2930
[targets]
30-
test = ["FastTransforms", "Test"]
31+
test = ["Base64", "FastTransforms", "Test"]

src/ContinuumArrays.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function dot(x::Inclusion{T,<:AbstractInterval}, y::Inclusion{V,<:AbstractInterv
8888
a,b = endpoints(x.domain)
8989
convert(TV, b^3 - a^3)/3
9090
end
91-
91+
9292

9393
for find in (:findfirst, :findlast)
9494
@eval $find(f::Base.Fix2{typeof(isequal)}, d::Inclusion) = f.x in d.domain ? f.x : nothing
@@ -106,18 +106,15 @@ function checkindex(::Type{Bool}, inds::Inclusion{<:Any,<:AbstractInterval}, r::
106106
end
107107

108108

109-
BroadcastStyle(::Type{<:Inclusion}) = LazyQuasiArrayStyle{1}()
110-
BroadcastStyle(::Type{<:QuasiAdjoint{<:Any,<:Inclusion}}) = LazyQuasiArrayStyle{2}()
111-
BroadcastStyle(::Type{<:QuasiTranspose{<:Any,<:Inclusion}}) = LazyQuasiArrayStyle{2}()
112-
113-
114109
###
115110
# Maps
116111
###
117112

118113
# Affine map represents A*x .+ b
119114
abstract type AbstractAffineQuasiVector{T,AA,X,B} <: AbstractQuasiVector{T} end
120115

116+
show(io::IO, ::MIME"text/plain", a::AbstractAffineQuasiVector) = print(io, "$(a.A) * $(a.x) .+ ($(a.b))")
117+
121118
struct AffineQuasiVector{T,AA,X,B} <: AbstractAffineQuasiVector{T,AA,X,B}
122119
A::AA
123120
x::X
@@ -135,7 +132,7 @@ AffineQuasiVector(A, x::AffineQuasiVector, b) = AffineQuasiVector(A*x.A, x.x, A*
135132
axes(A::AbstractAffineQuasiVector) = axes(A.x)
136133
affine_getindex(A, k) = A.A*A.x[k] .+ A.b
137134
getindex(A::AbstractAffineQuasiVector, k::Number) = affine_getindex(A, k)
138-
function getindex(A::AbstractAffineQuasiVector, k::Inclusion)
135+
function getindex(A::AbstractAffineQuasiVector, k::Inclusion)
139136
@boundscheck A.x[k] # throws bounds error if k ≠ x
140137
A
141138
end
@@ -188,7 +185,7 @@ struct AffineMap{T,D,R} <: AbstractAffineQuasiVector{T,T,D,T}
188185
range::R
189186
end
190187

191-
AffineMap(domain::AbstractQuasiVector{T}, range::AbstractQuasiVector{V}) where {T,V} =
188+
AffineMap(domain::AbstractQuasiVector{T}, range::AbstractQuasiVector{V}) where {T,V} =
192189
AffineMap{promote_type(T,V), typeof(domain),typeof(range)}(domain,range)
193190

194191
measure(x::Inclusion) = last(x)-first(x)
@@ -223,6 +220,19 @@ affine(a::AbstractQuasiVector, b) = affine(a, Inclusion(b))
223220
affine(a, b) = affine(Inclusion(a), Inclusion(b))
224221

225222

223+
# mapped vectors
224+
const AffineMappedQuasiVector = SubQuasiArray{<:Any, 1, <:Any, <:Tuple{AbstractAffineQuasiVector}}
225+
const AffineMappedQuasiMatrix = SubQuasiArray{<:Any, 2, <:Any, <:Tuple{AbstractAffineQuasiVector,Slice}}
226+
227+
==(a::AffineMappedQuasiVector, b::AffineMappedQuasiVector) = parentindices(a) == parentindices(b) && parent(a) == parent(b)
228+
229+
_sum(V::AffineMappedQuasiVector, ::Colon) = parentindices(V)[1].A \ sum(parent(V))
230+
231+
# pretty print for bases
232+
show(io::IO, P::AffineMappedQuasiMatrix) = print(io, "$(parent(P)) affine mapped to $(parentindices(P)[1].x.domain)")
233+
show(io::IO, P::AffineMappedQuasiVector) = print(io, "$(parent(P)) affine mapped to $(parentindices(P)[1].x.domain)")
234+
show(io::IO, ::MIME"text/plain", P::AffineMappedQuasiMatrix) = show(io, P)
235+
show(io::IO, ::MIME"text/plain", P::AffineMappedQuasiVector) = show(io, P)
226236

227237
const QInfAxes = Union{Inclusion,AbstractAffineQuasiVector}
228238

src/bases/bases.jl

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ struct BasisLayout <: AbstractBasisLayout end
1010
struct SubBasisLayout <: AbstractBasisLayout end
1111
struct MappedBasisLayout <: AbstractBasisLayout end
1212
struct WeightedBasisLayout <: AbstractBasisLayout end
13+
struct SubWeightedBasisLayout <: AbstractBasisLayout end
14+
struct MappedWeightedBasisLayout <: AbstractBasisLayout end
15+
16+
SubBasisLayouts = Union{SubBasisLayout,SubWeightedBasisLayout}
17+
WeightedBasisLayouts = Union{WeightedBasisLayout,SubWeightedBasisLayout,MappedWeightedBasisLayout}
18+
MappedBasisLayouts = Union{MappedBasisLayout,MappedWeightedBasisLayout}
1319

1420
abstract type AbstractAdjointBasisLayout <: AbstractQuasiLazyLayout end
1521
struct AdjointBasisLayout <: AbstractAdjointBasisLayout end
@@ -19,11 +25,21 @@ struct AdjointMappedBasisLayout <: AbstractAdjointBasisLayout end
1925
MemoryLayout(::Type{<:Basis}) = BasisLayout()
2026
MemoryLayout(::Type{<:Weight}) = WeightLayout()
2127

22-
adjointlayout(::Type, ::BasisLayout) = AdjointBasisLayout()
28+
adjointlayout(::Type, ::AbstractBasisLayout) = AdjointBasisLayout()
2329
adjointlayout(::Type, ::SubBasisLayout) = AdjointSubBasisLayout()
24-
adjointlayout(::Type, ::MappedBasisLayout) = AdjointMappedBasisLayout()
30+
adjointlayout(::Type, ::MappedBasisLayouts) = AdjointMappedBasisLayout()
2531
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::BasisLayout) = WeightedBasisLayout()
2632
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::SubBasisLayout) = WeightedBasisLayout()
33+
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::MappedBasisLayouts) = MappedWeightedBasisLayout()
34+
35+
# A sub of a weight is still a weight
36+
sublayout(::WeightLayout, _) = WeightLayout()
37+
38+
## Weighted basis interface
39+
unweightedbasis(P::BroadcastQuasiMatrix{<:Any,typeof(*),<:Tuple{AbstractQuasiVector,AbstractQuasiMatrix}}) = last(P.args)
40+
unweightedbasis(V::SubQuasiArray) = view(unweightedbasis(parent(V)), parentindices(V)...)
41+
42+
2743

2844
# Default is lazy
2945
ApplyStyle(::typeof(pinv), ::Type{<:Basis}) = LazyQuasiArrayApplyStyle()
@@ -48,25 +64,25 @@ end
4864
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)},<:Any,<:AbstractQuasiVector}) =
4965
transform_ldiv(L.A, L.B)
5066

51-
function copy(P::Ldiv{BasisLayout,BasisLayout})
67+
function copy(P::Ldiv{<:AbstractBasisLayout,<:AbstractBasisLayout})
5268
A, B = P.A, P.B
5369
A == B || throw(ArgumentError("Override copy for $(typeof(A)) \\ $(typeof(B))"))
5470
SquareEye{eltype(P)}((axes(A,2),))
5571
end
56-
function copy(P::Ldiv{SubBasisLayout,SubBasisLayout})
72+
function copy(P::Ldiv{<:SubBasisLayouts,<:SubBasisLayouts})
5773
A, B = P.A, P.B
5874
parent(A) == parent(B) ||
5975
throw(ArgumentError("Override copy for $(typeof(A)) \\ $(typeof(B))"))
6076
Eye{eltype(P)}((axes(A,2),axes(B,2)))
6177
end
6278

63-
@inline function copy(P::Ldiv{MappedBasisLayout,MappedBasisLayout})
79+
@inline function copy(P::Ldiv{<:MappedBasisLayouts,<:MappedBasisLayouts})
6480
A, B = P.A, P.B
6581
demap(A)\demap(B)
6682
end
6783

68-
@inline copy(L::Ldiv{BasisLayout,SubBasisLayout}) = apply(\, L.A, ApplyQuasiArray(L.B))
69-
@inline function copy(L::Ldiv{SubBasisLayout,BasisLayout})
84+
@inline copy(L::Ldiv{<:AbstractBasisLayout,<:SubBasisLayouts}) = apply(\, L.A, ApplyQuasiArray(L.B))
85+
@inline function copy(L::Ldiv{<:SubBasisLayouts,<:AbstractBasisLayout})
7086
P = parent(L.A)
7187
kr, jr = parentindices(L.A)
7288
layout_getindex(apply(\, P, L.B), jr, :) # avoid sparse arrays
@@ -83,7 +99,7 @@ end
8399
_grid(_, P) = error("Overload Grid")
84100
_grid(::MappedBasisLayout, P) = igetindex.(Ref(parentindices(P)[1]), grid(demap(P)))
85101
_grid(::SubBasisLayout, P) = grid(parent(P))
86-
_grid(::WeightedBasisLayout, P) = grid(last(P.args))
102+
_grid(::WeightedBasisLayouts, P) = grid(unweightedbasis(P))
87103
grid(P) = _grid(MemoryLayout(typeof(P)), P)
88104

89105
struct TransformFactorization{T,Grid,Plan,IPlan} <: Factorization{T}
@@ -121,6 +137,11 @@ end
121137
\(a::ProjectionFactorization, b::AbstractVector) = (a.F \ b)[a.inds]
122138

123139
_factorize(::SubBasisLayout, L) = ProjectionFactorization(factorize(parent(L)), parentindices(L)[2])
140+
# function _factorize(::MappedBasisLayout, L)
141+
# kr, jr = parentindices(L)
142+
# P = parent(L)
143+
# ProjectionFactorization(factorize(view(P,:,jr)), parentindices(L)[2])
144+
# end
124145

125146
transform_ldiv(A, B, _) = factorize(A) \ B
126147
transform_ldiv(A, B) = transform_ldiv(A, B, axes(A))
@@ -131,18 +152,6 @@ copy(L::Ldiv{<:AbstractBasisLayout,<:Any,<:Any,<:AbstractQuasiVector}) =
131152
copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) =
132153
transform_ldiv(L.A, L.B)
133154

134-
function copy(L::Ldiv{ApplyLayout{typeof(*)},<:AbstractBasisLayout})
135-
args = arguments(ApplyLayout{typeof(*)}(), L.A)
136-
@assert length(args) == 2 # temporary
137-
apply(\, last(args), apply(\, first(args), L.B))
138-
end
139-
140-
141-
function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:AbstractQuasiMatrix,<:AbstractQuasiVector})
142-
p,T = factorize(L.A)
143-
T \ L.B[p]
144-
end
145-
146155

147156
##
148157
# Algebra
@@ -151,9 +160,10 @@ end
151160
# struct ExpansionLayout <: MemoryLayout end
152161
# applylayout(::Type{typeof(*)}, ::BasisLayout, _) = ExpansionLayout()
153162

154-
const Expansion{T,Space<:Basis,Coeffs<:AbstractVector} = ApplyQuasiVector{T,typeof(*),<:Tuple{Space,Coeffs}}
163+
const Expansion{T,Space<:AbstractQuasiMatrix,Coeffs<:AbstractVector} = ApplyQuasiVector{T,typeof(*),<:Tuple{Space,Coeffs}}
155164

156-
basis(v::AbstractQuasiVector) = v.args[1]
165+
166+
basis(v::Expansion) = v.args[1]
157167

158168
for op in (:*, :\)
159169
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), x::Number, f::Expansion)
@@ -169,11 +179,19 @@ for op in (:*, :/)
169179
end
170180

171181

172-
function broadcastbasis(::typeof(+), a, b)
173-
a b && error("Overload broadcastbasis(::typeof(+), ::$(typeof(a)), ::$(typeof(b)))")
182+
function _broadcastbasis(::typeof(+), _, _, a, b)
183+
try
184+
a b && error("Overload broadcastbasis(::typeof(+), ::$(typeof(a)), ::$(typeof(b)))")
185+
catch
186+
error("Overload broadcastbasis(::typeof(+), ::$(typeof(a)), ::$(typeof(b)))")
187+
end
174188
a
175189
end
176190

191+
_broadcastbasis(::typeof(+), ::MappedBasisLayouts, ::MappedBasisLayouts, a, b) = broadcastbasis(+, demap(a), demap(b))[basismap(a), :]
192+
193+
broadcastbasis(::typeof(+), a, b) = _broadcastbasis(+, MemoryLayout(a), MemoryLayout(b), a, b)
194+
177195
broadcastbasis(::typeof(-), a, b) = broadcastbasis(+, a, b)
178196

179197
for op in (:+, :-)
@@ -228,48 +246,43 @@ end
228246
(Derivative(axes(P,1))*P*kr.A)[kr,jr]
229247
end
230248

231-
function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:AbstractQuasiMatrix})
232-
args = arguments(L.B)
233-
# this is a temporary hack
234-
if args isa Tuple{AbstractQuasiMatrix,Number}
235-
(L.A \ first(args))*last(args)
236-
elseif args isa Tuple{Number,AbstractQuasiMatrix}
237-
first(args)*(L.A \ last(args))
238-
else
239-
error("Not implemented")
240-
end
241-
end
242-
243-
244249
# we represent as a Mul with a banded matrix
245250
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractUnitRange}}) = SubBasisLayout()
246251
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:AbstractAffineQuasiVector,<:AbstractUnitRange}}) = MappedBasisLayout()
252+
sublayout(::WeightedBasisLayout, ::Type{<:Tuple{<:AbstractAffineQuasiVector,<:AbstractUnitRange}}) = MappedWeightedBasisLayout()
253+
sublayout(::WeightedBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractUnitRange}}) = SubWeightedBasisLayout()
247254

248255
@inline sub_materialize(::AbstractBasisLayout, V::AbstractQuasiArray) = V
249256
@inline sub_materialize(::AbstractBasisLayout, V::AbstractArray) = V
250257

251258
demap(x) = x
252-
demap(V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Any,<:Slice}}) = parent(V)
259+
demap(x::BroadcastQuasiArray) = BroadcastQuasiArray(x.f, map(demap, arguments(x))...)
260+
demap(V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:AbstractAffineQuasiVector,<:Slice}}) = parent(V)
261+
demap(V::SubQuasiArray{<:Any,1,<:Any,<:Tuple{<:AbstractAffineQuasiVector}}) = parent(V)
253262
function demap(V::SubQuasiArray{<:Any,2})
254263
kr, jr = parentindices(V)
255264
demap(parent(V)[kr,:])[:,jr]
256265
end
257266

267+
basismap(x::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}) = parentindices(x)[1]
268+
basismap(x::SubQuasiArray{<:Any,1,<:Any,<:Tuple{<:AbstractAffineQuasiVector}}) = parentindices(x)[1]
269+
basismap(x::BroadcastQuasiArray) = basismap(x.args[1])
270+
258271

259272
##
260273
# SubLayout behaves like ApplyLayout{typeof(*)}
261274

262-
combine_mul_styles(::SubBasisLayout) = combine_mul_styles(ApplyLayout{typeof(*)}())
263-
_mul_arguments(::SubBasisLayout, A) = _mul_arguments(ApplyLayout{typeof(*)}(), A)
264-
arguments(::SubBasisLayout, A) = arguments(ApplyLayout{typeof(*)}(), A)
265-
call(::SubBasisLayout, ::SubQuasiArray) = *
275+
combine_mul_styles(::SubBasisLayouts) = combine_mul_styles(ApplyLayout{typeof(*)}())
276+
_mul_arguments(::SubBasisLayouts, A) = _mul_arguments(ApplyLayout{typeof(*)}(), A)
277+
arguments(::SubBasisLayouts, A) = arguments(ApplyLayout{typeof(*)}(), A)
278+
call(::SubBasisLayouts, ::SubQuasiArray) = *
266279

267280
combine_mul_styles(::AdjointSubBasisLayout) = combine_mul_styles(ApplyLayout{typeof(*)}())
268281
_mul_arguments(::AdjointSubBasisLayout, A) = _mul_arguments(ApplyLayout{typeof(*)}(), A)
269282
arguments(::AdjointSubBasisLayout, A) = arguments(ApplyLayout{typeof(*)}(), A)
270283
call(::AdjointSubBasisLayout, ::SubQuasiArray) = *
271284

272-
copy(M::Mul{AdjointSubBasisLayout,SubBasisLayout}) = apply(*, arguments(M.A)..., arguments(M.B)...)
285+
copy(M::Mul{AdjointSubBasisLayout,<:SubBasisLayouts}) = apply(*, arguments(M.A)..., arguments(M.B)...)
273286

274287
function arguments(::ApplyLayout{typeof(*)}, V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Inclusion,<:AbstractUnitRange}})
275288
A = parent(V)
@@ -297,8 +310,8 @@ function __sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, ::Colon)
297310
first(apply(*, sum(a[1]; dims=1), tail(a)...))
298311
end
299312

300-
function __sum(::MappedBasisLayout, V::AbstractQuasiArray, dims)
301-
kr, jr = parentindices(V)
313+
function __sum(::MappedBasisLayouts, V::AbstractQuasiArray, dims)
314+
kr = basismap(V)
302315
@assert kr isa AbstractAffineQuasiVector
303316
sum(demap(V); dims=dims)/kr.A
304317
end

src/operators.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ end
8585
@simplify *(A::QuasiAdjoint{<:Any,<:DiracDelta}, B::AbstractQuasiVector) = B[parent(A).x]
8686
@simplify *(A::QuasiAdjoint{<:Any,<:DiracDelta}, B::AbstractQuasiMatrix) = B[parent(A).x,:]
8787

88+
show(io::IO, δ::DiracDelta) = print(io, "δ at $(δ.x) over $(axes(δ,1))")
89+
show(io::IO, ::MIME"text/plain", δ::DiracDelta) = show(io, δ)
90+
8891
#########
8992
# Derivative
9093
#########

0 commit comments

Comments
 (0)