Skip to content

Commit b880c37

Browse files
authored
Redesign grid to take in an n corresponding to the number of coefficients needed (#138)
* have grid work without subarray * Update test_chebyshev.jl * simplify grid overloading * v0.12 * arr -> szs * tests pass * minor changes * increase cov * mul vectors
1 parent 75bf1d2 commit b880c37

File tree

8 files changed

+88
-34
lines changed

8 files changed

+88
-34
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.11.2"
3+
version = "0.12"
4+
45

56
[deps]
67
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/ContinuumArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclu
2121
AbstractQuasiFill, UnionDomain, __sum, _cumsum, __cumsum, applylayout, _equals, layout_broadcasted, PolynomialLayout
2222
import InfiniteArrays: Infinity, InfAxes
2323

24-
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, ℵ₁, Inclusion, Basis, grid, plotgrid, affine, .., transform, expand, plan_transform
24+
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, ℵ₁, Inclusion, Basis, grid, plotgrid, affine, .., transform, expand, plan_transform, basis, coefficients
2525

2626

2727

src/bases/bases.jl

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::Basis) where Basis<:Abstrac
3030
sublayout(::WeightLayout, _) = WeightLayout()
3131
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{Map,AbstractVector}}) = MappedBasisLayout()
3232

33+
# copy with an Inclusion can not be materialized
34+
copy(V::SubQuasiArray{<:Any,N,<:Basis,<:Tuple{Inclusion,Vararg{Any}}, trfl}) where {N,trfl} = V
35+
3336

3437
## Weighted basis interface
3538
unweighted(P::BroadcastQuasiMatrix{<:Any,typeof(*),<:Tuple{AbstractQuasiVector,AbstractQuasiMatrix}}) = last(P.args)
@@ -148,12 +151,26 @@ copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)},<:Any,<:AbstractQua
148151

149152

150153
# expansion
151-
_grid(_, P) = error("Overload Grid")
154+
_grid(_, P, n...) = error("Overload Grid")
155+
156+
_grid(::MappedBasisLayout, P, n...) = invmap(parentindices(P)[1])[grid(demap(P), n...)]
157+
_grid(::SubBasisLayout, P::AbstractQuasiMatrix, n) = grid(parent(P), maximum(parentindices(P)[2][n]))
158+
_grid(::SubBasisLayout, P::AbstractQuasiMatrix) = grid(parent(P), maximum(parentindices(P)[2]))
159+
_grid(::WeightedBasisLayouts, P, n...) = grid(unweighted(P), n...)
152160

153-
_grid(::MappedBasisLayout, P) = invmap(parentindices(P)[1])[grid(demap(P))]
154-
_grid(::SubBasisLayout, P) = grid(parent(P))
155-
_grid(::WeightedBasisLayouts, P) = grid(unweighted(P))
156-
grid(P) = _grid(MemoryLayout(P), P)
161+
162+
"""
163+
grid(P, n...)
164+
165+
Creates a grid of points. if `n` is unspecified it will
166+
be sufficient number of points to determine `size(P,2)`
167+
coefficients. Otherwise its enough points to determine `n`
168+
coefficients.
169+
"""
170+
grid(P, n...) = _grid(MemoryLayout(P), P, n...)
171+
172+
173+
# values(f) =
157174

158175

159176
struct TransformFactorization{T,Grid,Plan} <: Factorization{T}
@@ -239,15 +256,18 @@ function *(P::InvPlan, X::AbstractArray)
239256
end
240257

241258

242-
function plan_grid_transform(L, arr, dims=1:ndims(arr))
259+
function plan_grid_transform(L, szs::NTuple{N,Int}, dims=1:N) where N
243260
p = grid(L)
244261
p, InvPlan(factorize(L[p,:]), dims)
245262
end
246263

247-
plan_transform(P, arr, dims...) = plan_grid_transform(P, arr, dims...)[2]
264+
plan_grid_transform(L, arr::AbstractArray{<:Any,N}, dims=1:N) where N =
265+
plan_grid_transform(L, size(arr), dims)
266+
267+
plan_transform(P, szs, dims...) = plan_grid_transform(P, szs, dims...)[2]
248268

249269
_factorize(::AbstractBasisLayout, L, dims...; kws...) =
250-
TransformFactorization(plan_grid_transform(L, Array{eltype(L)}(undef, size(L,2), dims...), 1)...)
270+
TransformFactorization(plan_grid_transform(L, (size(L,2), dims...), 1)...)
251271

252272

253273

@@ -273,7 +293,7 @@ _sub_factorize(::Tuple{Any,Int}, (kr,jr)::Tuple{Any,OneTo}, L, dims...; kws...)
273293

274294
# ∞-dimensional parents need to use transforms. For now we assume the size of the transform is equal to the size of the truncation
275295
_sub_factorize(::Tuple{Any,Any}, (kr,jr)::Tuple{Any,OneTo}, L, dims...; kws...) =
276-
TransformFactorization(plan_grid_transform(parent(L), Array{eltype(L)}(undef, last(jr), dims...), 1)...)
296+
TransformFactorization(plan_grid_transform(parent(L), (last(jr), dims...), 1)...)
277297

278298
# If jr is not OneTo we project
279299
_sub_factorize(::Tuple{Any,Any}, (kr,jr), L, dims...; kws...) =
@@ -377,6 +397,7 @@ applylayout(::Type{typeof(*)}, ::Lay, ::Union{PaddedLayout,AbstractStridedLayout
377397

378398
basis(v::ApplyQuasiArray{<:Any,N,typeof(*)}) where N = v.args[1]
379399
coefficients(v::ApplyQuasiArray{<:Any,N,typeof(*),<:Tuple{Any,Any}}) where N = v.args[2]
400+
coefficients(v::ApplyQuasiArray{<:Any,N,typeof(*),<:Tuple{Any,Any,Vararg{Any}}}) where N = ApplyArray(*, tail(v.args)...)
380401

381402

382403
function unweighted(lay::ExpansionLayout, a)

src/bases/splines.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ function getindex(B::HeavisideSpline{T}, x::Number, k::Int) where T
4747
return zero(T)
4848
end
4949

50-
grid(L::HeavisideSpline) = L.points[1:end-1] .+ diff(L.points)/2
51-
grid(L::LinearSpline) = L.points
50+
grid(L::HeavisideSpline, n...) = L.points[1:end-1] .+ diff(L.points)/2
51+
grid(L::LinearSpline, n...) = L.points
5252

5353
## Sub-bases
5454

src/plotting.jl

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,45 @@
11

2+
const MAX_PLOT_POINTS = 10_000 # above this rendering is too slow
23

3-
_mul_plotgrid(_, args) = grid(first(args))
4-
_mul_plotgrid(::Tuple{Any,PaddedLayout}, (P,c)) = plotgrid(P[:,colsupport(c)])
4+
5+
"""
6+
plotgrid(P, n...)
7+
8+
returns a grid of points suitable for plotting. This may include
9+
endpoints or singular points not included in `grid`. `n` specifies
10+
the number of coefficients.
11+
"""
12+
13+
plotgrid(P, n...) = _plotgrid(MemoryLayout(P), P, n...)
14+
_plotgrid(lay, P, n=size(P,2)) = grid(P, min(n,MAX_PLOT_POINTS))
15+
16+
_plotgrid(::WeightedBasisLayouts, wP, n...) = plotgrid(unweighted(wP), n...)
17+
_plotgrid(::MappedBasisLayout, P, n...) = invmap(parentindices(P)[1])[plotgrid(demap(P), n...)]
18+
_plotgrid(::SubBasisLayout, P::AbstractQuasiMatrix, n) = plotgrid(parent(P), maximum(parentindices(P)[2][n]))
19+
_plotgrid(::SubBasisLayout, P::AbstractQuasiMatrix) = plotgrid(parent(P), maximum(parentindices(P)[2]))
20+
21+
22+
_mul_plotgrid(_, args) = _plotgrid(UnknownLayout(), first(args))
23+
_mul_plotgrid(::Tuple{Any,PaddedLayout}, (P,c)) = plotgrid(P, maximum(colsupport(c)))
524

625
function _plotgrid(lay::ExpansionLayout, P)
726
args = arguments(lay,P)
827
_mul_plotgrid(map(MemoryLayout,args), args)
928
end
1029

11-
_plotgrid(_, P) = grid(P)
12-
13-
_plotgrid(::WeightedBasisLayouts, wP) = plotgrid(unweighted(wP))
14-
_plotgrid(::MappedBasisLayout, P) = invmap(parentindices(P)[1])[plotgrid(demap(P))]
15-
16-
plotgrid(g) = _plotgrid(MemoryLayout(g), g)
17-
1830
_split_svec(x) = (x,)
1931
_split_svec(x::AbstractArray{<:StaticVector{2}}) = (map(first,x), map(last,x))
2032

2133
plotvalues(g::AbstractQuasiVector, x) = g[x]
2234
plotvalues(g::AbstractQuasiMatrix, x) = g[x,:]
35+
plotvalues(g::AbstractQuasiArray) = plotvalues(g, plotgrid(g))
2336

24-
@recipe function f(g::AbstractQuasiArray)
37+
function plotgridvalues(g)
2538
x = plotgrid(g)
26-
tuple(_split_svec(x)..., plotvalues(g,x))
39+
x, plotvalues(g,x)
40+
end
41+
42+
@recipe function f(g::AbstractQuasiArray)
43+
x,v = plotgridvalues(g)
44+
tuple(_split_svec(x)..., v)
2745
end

test/runtests.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ContinuumArrays, QuasiArrays, IntervalSets, DomainSets, FillArrays, LinearAlgebra, BandedMatrices, InfiniteArrays, Test, Base64, RecipesBase
22
import ContinuumArrays: ℵ₁, materialize, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, ℵ₁,
33
MappedBasisLayout, AdjointMappedBasisLayout, MappedWeightedBasisLayout, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
4-
basis, invmap, Map, checkpoints, _plotgrid, mul
4+
basis, invmap, Map, checkpoints, _plotgrid, mul, plotvalues
55
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
66
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivStyle, MulStyle
77

@@ -89,9 +89,16 @@ include("test_basisconcat.jl")
8989
rep = RecipesBase.apply_recipe(Dict{Symbol, Any}(), L)
9090
@test rep[1].args == (L.points,L[L.points,:])
9191

92+
rep = RecipesBase.apply_recipe(Dict{Symbol, Any}(), L[:,1:3])
93+
@test rep[1].args == (L.points,L[L.points,1:3])
94+
95+
@test plotgrid(L[:,1:3],3) == grid(L[:,1:3]) == grid(L[:,1:3],3) == L.points
96+
97+
9298
u = L*randn(6)
9399
rep = RecipesBase.apply_recipe(Dict{Symbol, Any}(), u)
94100
@test rep[1].args == (L.points,u[L.points])
101+
@test plotvalues(u) == u[plotgrid(u)]
95102

96103
@testset "padded" begin
97104
u = L * Vcat(rand(3), Zeros(3))

test/test_chebyshev.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
using ContinuumArrays, LinearAlgebra, FastTransforms, QuasiArrays, Test
2-
import ContinuumArrays: Basis, Weight, Map, LazyQuasiArrayStyle, TransformFactorization, ExpansionLayout
1+
using ContinuumArrays, LinearAlgebra, FastTransforms, QuasiArrays, ArrayLayouts, Test
2+
import ContinuumArrays: Basis, Weight, Map, LazyQuasiArrayStyle, TransformFactorization,
3+
ExpansionLayout, checkpoints, MappedBasisLayout, MappedWeightedBasisLayout,
4+
SubWeightedBasisLayout, WeightedBasisLayout, WeightLayout
35

46
"""
57
This is a simple implementation of Chebyshev for testing. Use ClassicalOrthogonalPolynomials
6-
for the real implementation.
8+
for the real implementation.
79
"""
810
struct Chebyshev <: Basis{Float64}
911
n::Int
@@ -14,17 +16,14 @@ struct ChebyshevWeight <: Weight{Float64} end
1416
Base.:(==)(::Chebyshev, ::Chebyshev) = true
1517
Base.:(==)(::ChebyshevWeight, ::ChebyshevWeight) = true
1618
Base.axes(T::Chebyshev) = (Inclusion(-1..1), Base.OneTo(T.n))
17-
ContinuumArrays.grid(T::Chebyshev) = chebyshevpoints(Float64, T.n, Val(1))
19+
ContinuumArrays.grid(T::Chebyshev, n...) = chebyshevpoints(Float64, T.n, Val(1))
1820
Base.axes(T::ChebyshevWeight) = (Inclusion(-1..1),)
1921

2022
Base.getindex(::Chebyshev, x::Float64, n::Int) = cos((n-1)*acos(x))
2123
Base.getindex(::ChebyshevWeight, x::Float64) = 1/sqrt(1-x^2)
2224
Base.getindex(w::ChebyshevWeight, ::Inclusion) = w # TODO: make automatic
2325

24-
LinearAlgebra.factorize(L::Chebyshev) =
25-
TransformFactorization(grid(L), plan_chebyshevtransform(Array{Float64}(undef, size(L,2))))
26-
LinearAlgebra.factorize(L::Chebyshev, n) =
27-
TransformFactorization(grid(L), plan_chebyshevtransform(Array{Float64}(undef, size(L,2),n),1))
26+
ContinuumArrays.plan_grid_transform(L::Chebyshev, szs::NTuple{N,Int}, dims=1:N) where N = grid(L), plan_chebyshevtransform(Array{eltype(L)}(undef, szs...), dims)
2827

2928
# This is wrong but just for tests
3029
QuasiArrays.layout_broadcasted(::Tuple{ExpansionLayout,Any}, ::typeof(*), a::ApplyQuasiVector{<:Any,typeof(*),<:Tuple{Chebyshev,Any}}, b::Chebyshev) = b * Matrix(I, 5, 5)

test/test_splines.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ContinuumArrays, LinearAlgebra, Base64, FillArrays, QuasiArrays, BandedMatrices, Test
22
using QuasiArrays: ApplyQuasiArray, ApplyStyle, MemoryLayout, mul, MulQuasiMatrix, Vec
33
import LazyArrays: MulStyle, LdivStyle, arguments, applied, apply
4-
import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout, SubBasisLayout, AdjointMappedBasisLayout, MappedBasisLayout, coefficients
4+
import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout, SubBasisLayout, AdjointMappedBasisLayout, MappedBasisLayout
55

66
@testset "Splines" begin
77
@testset "HeavisideSpline" begin
@@ -28,6 +28,8 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
2828
@test MemoryLayout(typeof(H)) == BasisLayout()
2929
@test ApplyStyle(*, typeof(H), typeof([1,2])) isa MulStyle
3030

31+
@test copy(H[:,1:2]) == H[:,1:2]
32+
3133
f = H*[1,2]
3234
@test f isa ApplyQuasiArray
3335
@test axes(f) == (Inclusion(1.0..3.0),)
@@ -495,4 +497,10 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
495497
end
496498
@test PX X
497499
end
500+
501+
@testset "Mul coefficients" begin
502+
L = LinearSpline(0:5)
503+
u = ApplyQuasiArray(*, L, randn(6,5), randn(5))
504+
@test coefficients(u) L \ u
505+
end
498506
end

0 commit comments

Comments
 (0)