Skip to content

Commit 05df318

Browse files
authored
Support demap for Expansions (#98)
* Support demap for Expansions * Update runtests.jl * Update bases.jl * add tests * quasimatrix ldiv * increase coverage
1 parent a27fcb3 commit 05df318

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.8.0"
3+
version = "0.8.1"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/bases/bases.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ function copy(P::Ldiv{<:MappedBasisLayouts,<:AbstractLazyLayout})
8989
A,B = P.A, P.B
9090
demap(A) \ B[invmap(basismap(A))]
9191
end
92-
copy(P::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)}}) = copy(Ldiv{UnknownLayout,ApplyLayout{typeof(*)}}(P.A,P.B))
92+
copy(L::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)}}) = copy(Ldiv{UnknownLayout,ApplyLayout{typeof(*)}}(L.A,L.B))
93+
copy(L::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = transform_ldiv(L.A, L.B)
9394

9495
@inline copy(L::Ldiv{<:AbstractBasisLayout,<:SubBasisLayouts}) = apply(\, L.A, ApplyQuasiArray(L.B))
9596
@inline function copy(L::Ldiv{<:SubBasisLayouts,<:AbstractBasisLayout})
@@ -110,7 +111,7 @@ end
110111
a,b = arguments(B)
111112
@assert a isa AbstractQuasiVector # Only works for vec .* mat
112113
ab = (A * (A \ a)) .* b # broadcasted should be overloaded
113-
MemoryLayout(ab) isa BroadcastLayout && error("Overload broadcasted(_, ::typeof(*), ::$(typeof(ab.args[1])), ::$(typeof(b)))")
114+
MemoryLayout(ab) isa BroadcastLayout && return transform_ldiv(A, ab)
114115
A \ ab
115116
end
116117

@@ -119,6 +120,10 @@ _broadcast_mul_ldiv(_, A, B) = copy(Ldiv{typeof(MemoryLayout(A)),UnknownLayout}(
119120
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
120121
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
121122

123+
# ambiguity
124+
copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
125+
copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
126+
122127

123128
# expansion
124129
_grid(_, P) = error("Overload Grid")
@@ -172,6 +177,7 @@ struct ProjectionFactorization{T, FAC<:Factorization{T}, INDS} <: Factorization{
172177
end
173178

174179
\(a::ProjectionFactorization, b::AbstractQuasiVector) = (a.F \ b)[a.inds]
180+
\(a::ProjectionFactorization, b::AbstractQuasiMatrix) = (a.F \ b)[a.inds,:]
175181
\(a::ProjectionFactorization, b::AbstractVector) = (a.F \ b)[a.inds]
176182

177183
_factorize(::SubBasisLayout, L) = ProjectionFactorization(factorize(parent(L)), parentindices(L)[2])
@@ -183,6 +189,8 @@ end
183189

184190
\(a::MappedFactorization, b::AbstractQuasiVector) = a.F \ view(b, a.map)
185191
\(a::MappedFactorization, b::AbstractVector) = a.F \ b
192+
\(a::MappedFactorization, b::AbstractQuasiMatrix) = a.F \ view(b, a.map, :)
193+
186194

187195
function invmap end
188196

@@ -268,6 +276,9 @@ function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(*), a::Expansion, f::Exp
268276
end
269277

270278

279+
_function_mult_broadcasted(_, _, a, B) = Base.Broadcast.Broadcasted{LazyQuasiArrayStyle{2}}(*, (a, B))
280+
broadcasted(::LazyQuasiArrayStyle{2}, ::typeof(*), a::Expansion, B::AbstractQuasiMatrix) = _function_mult_broadcasted(MemoryLayout(a), MemoryLayout(B), a, B)
281+
271282
@eval function ==(f::Expansion, g::Expansion)
272283
S,c = arguments(f)
273284
T,d = arguments(g)
@@ -351,6 +362,11 @@ function demap(V::SubQuasiArray{<:Any,2})
351362
kr, jr = parentindices(V)
352363
demap(parent(V)[kr,:])[:,jr]
353364
end
365+
function demap(wB::ApplyQuasiArray{<:Any,N,typeof(*)}) where N
366+
a = arguments(wB)
367+
*(demap(first(a)), tail(a)...)
368+
end
369+
354370

355371
basismap(x::SubQuasiArray) = parentindices(x)[1]
356372
basismap(x::BroadcastQuasiArray) = basismap(x.args[1])

test/runtests.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,8 @@ end
9797
@test f[2.1] 2
9898

9999
@test @inferred(H'H) == @inferred(materialize(applied(*,H',H))) == Eye(2)
100-
if VERSION < v"1.6-"
101-
@test summary(f) == "(Spline{0,Float64,Array{$Int,1}}) * (2-element Array{$Int,1})"
102-
@test stringmime("text/plain", f) == "Spline{0,Float64,Array{$Int,1}} * [1, 2]"
103-
else
104-
@test summary(f) == "(HeavisideSpline{Float64, Vector{$Int}}) * (2-element Vector{$Int})"
105-
@test stringmime("text/plain", f) == "HeavisideSpline{Float64, Vector{$Int}} * [1, 2]"
106-
end
100+
@test summary(f) == "(HeavisideSpline{Float64, Vector{$Int}}) * (2-element Vector{$Int})"
101+
@test stringmime("text/plain", f) == "HeavisideSpline{Float64, Vector{$Int}} * [1, 2]"
107102
end
108103

109104
@testset "LinearSpline" begin
@@ -442,6 +437,13 @@ end
442437
@testset "vec demap" begin
443438
@test L[y,:] \ exp.(axes(L,1))[y] L[y,:] \ exp.(y) factorize(L[y,:]) \ exp.(y)
444439
@test ContinuumArrays.demap(view(axes(L,1),y)) == axes(L,1)
440+
441+
@test L[y,:] \ (y .* exp.(y)) L[y,:] \ BroadcastQuasiVector(y -> y*exp(y), y)
442+
@test L[y,:] \ (y .* L[y,1:3]) [L[y,:]\(y .* L[y,1]) L[y,:]\(y .* L[y,2]) L[y,:]\(y .* L[y,3])]
443+
444+
c = randn(size(L,2))
445+
@test L[y,:] \ (L[y,:] * c) c
446+
@test ContinuumArrays.demap(L[y,:] * c) == L*c
445447
end
446448
end
447449

@@ -457,11 +459,7 @@ end
457459
H = HeavisideSpline([1,2,3,6])
458460
B = H[5x .+ 1,:]
459461
u = H * [1,2,3]
460-
if VERSION < v"1.6-"
461-
@test stringmime("text/plain", B) == "Spline{0,Float64,Array{$Int,1}} affine mapped to 0..1"
462-
else
463-
@test stringmime("text/plain", B) == "HeavisideSpline{Float64, Vector{$Int}} affine mapped to 0..1"
464-
end
462+
@test stringmime("text/plain", B) == "HeavisideSpline{Float64, Vector{$Int}} affine mapped to 0..1"
465463
end
466464
end
467465

0 commit comments

Comments
 (0)