Skip to content

Commit 93141e2

Browse files
authored
Add plan_transform (#106)
* Add plan_transform * Work on plan_transform * Update ClassicalOrthogonalPolynomials.jl * Move cardinality * Add MulPlan * Update interlace.jl * plan_grid_transform * Update ci.yml * add transform tests * Update test_normalized.jl
1 parent 9305fc3 commit 93141e2

File tree

7 files changed

+138
-46
lines changed

7 files changed

+138
-46
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
matrix:
1212
version:
1313
- '1.7'
14-
- '^1.8.0-0'
14+
- '1'
1515
os:
1616
- ubuntu-latest
1717
- macOS-latest

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClassicalOrthogonalPolynomials"
22
uuid = "b30e2e7b-c4ee-47da-9d5f-2c5c27239acd"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.6.5"
4+
version = "0.6.6"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -29,11 +29,11 @@ ArrayLayouts = "0.8"
2929
BandedMatrices = "0.17"
3030
BlockArrays = "0.16.9"
3131
BlockBandedMatrices = "0.11.6"
32-
ContinuumArrays = "0.10"
32+
ContinuumArrays = "0.11"
3333
DomainSets = "0.5.6"
3434
FFTW = "1.1"
3535
FastGaussQuadrature = "0.4.3"
36-
FastTransforms = "0.13, 0.14"
36+
FastTransforms = "0.14.4"
3737
FillArrays = "0.13"
3838
HypergeometricFunctions = "0.3.4"
3939
InfiniteArrays = "0.12.3"

src/ClassicalOrthogonalPolynomials.jl

Lines changed: 69 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ import InfiniteLinearAlgebra: chop!, chop, choplength, compatible_resize!
3636
import ContinuumArrays: Basis, Weight, basis, @simplify, Identity, AbstractAffineQuasiVector, ProjectionFactorization,
3737
inbounds_getindex, grid, plotgrid, transform_ldiv, TransformFactorization, QInfAxes, broadcastbasis, ExpansionLayout, basismap,
3838
AffineQuasiVector, AffineMap, WeightLayout, AbstractWeightedBasisLayout, WeightedBasisLayout, WeightedBasisLayouts, demap, AbstractBasisLayout, BasisLayout,
39-
checkpoints, weight, unweighted, MappedBasisLayouts, __sum, invmap, plan_ldiv, layout_broadcasted, MappedBasisLayout, SubBasisLayout, _broadcastbasis
39+
checkpoints, weight, unweighted, MappedBasisLayouts, __sum, invmap, plan_ldiv, layout_broadcasted, MappedBasisLayout, SubBasisLayout, _broadcastbasis,
40+
plan_transform, plan_grid_transform
4041
import FastTransforms: Λ, forwardrecurrence, forwardrecurrence!, _forwardrecurrence!, clenshaw, clenshaw!,
4142
_forwardrecurrence_next, _clenshaw_next, check_clenshaw_recurrences, ChebyshevGrid, chebyshevpoints, Plan
4243

@@ -52,20 +53,13 @@ export OrthogonalPolynomial, Normalized, orthonormalpolynomial, LanczosPolynomia
5253
∞, Derivative, .., Inclusion,
5354
chebyshevt, chebyshevu, legendre, jacobi, ultraspherical,
5455
legendrep, jacobip, ultrasphericalc, laguerrel,hermiteh, normalizedjacobip,
55-
jacobimatrix, jacobiweight, legendreweight, chebyshevtweight, chebyshevuweight, Weighted, PiecewiseInterlace
56+
jacobimatrix, jacobiweight, legendreweight, chebyshevtweight, chebyshevuweight, Weighted, PiecewiseInterlace, plan_transform
5657

5758

5859
import Base: oneto
5960

6061

6162
include("interlace.jl")
62-
63-
64-
cardinality(::FullSpace{<:AbstractFloat}) = ℵ₁
65-
cardinality(::EuclideanDomain) = ℵ₁
66-
cardinality(::Union{DomainSets.RealNumbers,DomainSets.ComplexNumbers}) = ℵ₁
67-
cardinality(::Union{DomainSets.Integers,DomainSets.Rationals,DomainSets.NaturalNumbers}) = ℵ₀
68-
6963
include("standardchop.jl")
7064
include("adaptivetransform.jl")
7165

@@ -95,9 +89,9 @@ _equals(::MappedOPLayout, ::MappedOPLayout, P, Q) = demap(P) == demap(Q) && basi
9589
_equals(::MappedOPLayout, ::MappedBasisLayouts, P, Q) = demap(P) == demap(Q) && basismap(P) == basismap(Q)
9690
_equals(::MappedBasisLayouts, ::MappedOPLayout, P, Q) = demap(P) == demap(Q) && basismap(P) == basismap(Q)
9791

98-
_broadcastbasis(::typeof(+), ::MappedOPLayout, ::MappedOPLayout, P, Q) where {L,M} = _broadcastbasis(+, MappedBasisLayout(), MappedBasisLayout(), P, Q)
99-
_broadcastbasis(::typeof(+), ::MappedOPLayout, M::MappedBasisLayout, P, Q) where L = _broadcastbasis(+, MappedBasisLayout(), M, P, Q)
100-
_broadcastbasis(::typeof(+), L::MappedBasisLayout, ::MappedOPLayout, P, Q) where M = _broadcastbasis(+, L, MappedBasisLayout(), P, Q)
92+
_broadcastbasis(::typeof(+), ::MappedOPLayout, ::MappedOPLayout, P, Q) = _broadcastbasis(+, MappedBasisLayout(), MappedBasisLayout(), P, Q)
93+
_broadcastbasis(::typeof(+), ::MappedOPLayout, M::MappedBasisLayout, P, Q) = _broadcastbasis(+, MappedBasisLayout(), M, P, Q)
94+
_broadcastbasis(::typeof(+), L::MappedBasisLayout, ::MappedOPLayout, P, Q) = _broadcastbasis(+, L, MappedBasisLayout(), P, Q)
10195
__sum(::MappedOPLayout, A, dims) = __sum(MappedBasisLayout(), A, dims)
10296

10397
# demap to avoid Golub-Welsch fallback
@@ -231,11 +225,6 @@ function recurrencecoefficients(C::SubQuasiArray{T,2,<:Any,<:Tuple{AbstractAffin
231225
A * kr.A, A*kr.b + B, C
232226
end
233227

234-
235-
_vec(a) = vec(a)
236-
_vec(a::InfiniteArrays.ReshapedArray) = _vec(parent(a))
237-
_vec(a::Adjoint{<:Any,<:AbstractVector}) = a'
238-
239228
include("clenshaw.jl")
240229
include("ratios.jl")
241230
include("normalized.jl")
@@ -270,24 +259,76 @@ function golubwelsch(V::SubQuasiArray)
270259
x,w
271260
end
272261

273-
function factorize(L::SubQuasiArray{T,2,<:Normalized,<:Tuple{Inclusion,OneTo}}, dims...; kws...) where T
274-
x,w = golubwelsch(L)
275-
TransformFactorization(x, L[x,:]'*Diagonal(w))
262+
"""
263+
MulPlan(matrix, dims)
264+
265+
Takes a matrix and supports it applied to different dimensions.
266+
"""
267+
struct MulPlan{T, Fact, Dims} # <: Plan{T} We don't depend on AbstractFFTs
268+
matrix::Fact
269+
dims::Dims
276270
end
277271

272+
MulPlan(fact, dims) = MulPlan{eltype(fact), typeof(fact), typeof(dims)}(fact, dims)
273+
274+
function *(P::MulPlan{<:Any,<:Any,Int}, x::AbstractVector)
275+
@assert P.dims == 1
276+
P.matrix * x
277+
end
278278

279-
function factorize(L::SubQuasiArray{T,2,<:OrthogonalPolynomial,<:Tuple{Inclusion,OneTo}}, dims...; kws...) where T
280-
Q = Normalized(parent(L))[parentindices(L)...]
281-
D = L \ Q
282-
F = factorize(Q, dims...; kws...)
283-
TransformFactorization(F.grid, D*F.plan)
279+
function *(P::MulPlan{<:Any,<:Any,Int}, X::AbstractMatrix)
280+
if P.dims == 1
281+
P.matrix * X
282+
else
283+
@assert P.dims == 2
284+
permutedims(P.matrix * permutedims(X))
285+
end
284286
end
285287

286-
function factorize(L::SubQuasiArray{T,2,<:OrthogonalPolynomial,<:Tuple{<:Inclusion,<:AbstractUnitRange}}, dims...; kws...) where T
287-
_,jr = parentindices(L)
288-
ProjectionFactorization(factorize(parent(L)[:,oneto(maximum(jr))], dims...; kws...), jr)
288+
function *(P::MulPlan{<:Any,<:Any,Int}, X::AbstractArray{<:Any,3})
289+
Y = similar(X)
290+
if P.dims == 1
291+
for j in axes(X,3)
292+
Y[:,:,j] = P.matrix * X[:,:,j]
293+
end
294+
elseif P.dims == 2
295+
for k in axes(X,1)
296+
Y[k,:,:] = P.matrix * X[k,:,:]
297+
end
298+
else
299+
@assert P.dims == 3
300+
for k in axes(X,1), j in axes(X,2)
301+
Y[k,j,:] = P.matrix * X[k,j,:]
302+
end
303+
end
304+
Y
289305
end
290306

307+
function *(P::MulPlan, X::AbstractArray)
308+
for d in P.dims
309+
X = MulPlan(P.matrix, d) * X
310+
end
311+
X
312+
end
313+
314+
*(A::AbstractMatrix, P::MulPlan) = MulPlan(A*P.matrix, P.dims)
315+
316+
317+
function plan_grid_transform(Q::Normalized, arr, dims=1:ndims(arr))
318+
L = Q[:,OneTo(size(arr,1))]
319+
x,w = golubwelsch(L)
320+
x, MulPlan(L[x,:]'*Diagonal(w), dims)
321+
end
322+
323+
function plan_grid_transform(P::OrthogonalPolynomial, arr, dims...)
324+
Q = Normalized(P)
325+
x, A = plan_grid_transform(Q, arr, dims...)
326+
n = size(arr,1)
327+
D = (P \ Q)[1:n, 1:n]
328+
x, D * A
329+
end
330+
331+
291332
function \(A::SubQuasiArray{<:Any,2,<:OrthogonalPolynomial}, B::SubQuasiArray{<:Any,2,<:OrthogonalPolynomial})
292333
axes(A,1) == axes(B,1) || throw(DimensionMismatch())
293334
_,jA = parentindices(A)

src/classical/chebyshev.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,16 @@ Jacobi(C::ChebyshevU{T}) where T = Jacobi(one(T)/2,one(T)/2)
103103
#######
104104

105105

106-
factorize(L::SubQuasiArray{T,2,<:ChebyshevT,<:Tuple{<:Inclusion,<:OneTo}}) where T =
107-
TransformFactorization(grid(L), plan_chebyshevtransform(Array{T}(undef, size(L,2))))
108-
109-
factorize(L::SubQuasiArray{T,2,<:ChebyshevT,<:Tuple{<:Inclusion,<:OneTo}}, m) where T =
110-
TransformFactorization(grid(L), plan_chebyshevtransform(Array{T}(undef, size(L,2), m),1))
111-
112-
# TODO: extend plan_chebyshevutransform
113-
factorize(L::SubQuasiArray{T,2,<:ChebyshevU,<:Tuple{<:Inclusion,<:OneTo}}) where T<:FastTransforms.fftwNumber =
114-
TransformFactorization(grid(L), plan_chebyshevutransform(Array{T}(undef, size(L,2))))
106+
function plan_grid_transform(T::ChebyshevT, arr, dims...)
107+
n = size(arr,1)
108+
x = grid(T[:,oneto(n)])
109+
x, plan_chebyshevtransform(arr, dims...)
110+
end
111+
function plan_grid_transform(U::ChebyshevU{<:FastTransforms.fftwNumber}, arr, dims...)
112+
n = size(arr,1)
113+
x = grid(U[:,oneto(n)])
114+
x, plan_chebyshevutransform(arr, dims...)
115+
end
115116

116117

117118
########

src/interlace.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function \(F::SetindexFactorization{T,<:AbstractFill}, v::AbstractQuasiVector) w
241241
for (k,x) in enumerate(F̃.grid)
242242
data[k,:] = v[x]
243243
end
244-
blockvec(Matrix(transpose(F̃ \ data))) # call Matrix to avoid ReshapedArray
244+
blockvec(Matrix(transpose(F̃.plan * data))) # call Matrix to avoid ReshapedArray
245245
end
246246

247247

test/test_chebyshev.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ClassicalOrthogonalPolynomials, QuasiArrays, ContinuumArrays, BandedMatrices, LazyArrays,
1+
using ClassicalOrthogonalPolynomials, QuasiArrays, ContinuumArrays, BandedMatrices, LazyArrays,
22
FastTransforms, ArrayLayouts, Test, FillArrays, Base64, BlockArrays, LazyBandedMatrices, ForwardDiff
33
import ClassicalOrthogonalPolynomials: Clenshaw, recurrencecoefficients, clenshaw, paddeddata, jacobimatrix, oneto, Weighted, MappedOPLayout
44
import LazyArrays: ApplyStyle
@@ -120,6 +120,13 @@ import ContinuumArrays: MappedWeightedBasisLayout, Map, WeightedBasisLayout
120120
@time U = T / T \ cos.(x .* (0:1000)');
121121
@test U[0.1,:] cos.(0.1 * (0:1000))
122122
@test U[[0.1,0.2],:] [cos.(0.1 * (0:1000)'); cos.(0.2 * (0:1000)')]
123+
124+
# support tensors but for grids
125+
X = randn(150, 2, 2)
126+
F = plan_transform(T, X, 1)
127+
F_m = plan_transform(T, X[:,:,1], 1)
128+
@test (F * X)[:,:,1] F_m * X[:,:,1]
129+
@test (F * X)[:,:,2] F_m * X[:,:,2]
123130
end
124131
end
125132

@@ -183,11 +190,11 @@ import ContinuumArrays: MappedWeightedBasisLayout, Map, WeightedBasisLayout
183190
@test sum(wT; dims=1)[:,1:10] zeros(1,9)]
184191
@test sum(wT[:,1]) π
185192
@test iszero(sum(wT[:,2]))
186-
193+
187194
@test (wT \ wU)[1:10,1:10] inv(wU \ wT)[1:10,1:10]
188195
@test_skip (wU \ WT)[1,1] == 2
189196
end
190-
197+
191198
@testset "Derivative" begin
192199
x = axes(wT,1)
193200
D = Derivative(x)
@@ -445,7 +452,7 @@ import ContinuumArrays: MappedWeightedBasisLayout, Map, WeightedBasisLayout
445452
T = ChebyshevT(); U = ChebyshevU()
446453
D = Derivative(axes(T,1))
447454
V = view(T,:,[1,3,4])
448-
@test (U\(D*V))[1:5,:] == (U \ (V'D')')[1:5,:] == (U\(D*T))[1:5,[1,3,4]]
455+
@test (U\(D*V))[1:5,:] == (U \ (V'D')')[1:5,:] == (U\(D*T))[1:5,[1,3,4]]
449456
end
450457

451458
@testset "plot" begin

test/test_normalized.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,47 @@ import ContinuumArrays: MappedWeightedBasisLayout
209209
@test grid(Q[:,1:5]) == grid(Q[:,collect(1:5)]) == grid(P[:,1:5])
210210
@test plotgrid(Q[:,1:5]) == plotgrid(Q[:,collect(1:5)]) == plotgrid(P[:,1:5])
211211
end
212+
213+
@testset "Transform" begin
214+
Q = Normalized(Hermite())
215+
n = 20
216+
Qₙ = Q[:,Base.OneTo(n)]
217+
x = axes(Q,1)
218+
g = grid(Qₙ)
219+
v = exp.(g)
220+
P = plan_transform(Q, v)
221+
@test P * v Qₙ[g,:] \ exp.(g) transform(Qₙ, exp)
222+
223+
V = cos.(g .* (1:3)')
224+
P = plan_transform(Q, V, 1)
225+
@test P * V Qₙ \ cos.(x .* (1:3)')
226+
227+
X = randn(n, n)
228+
P₂ = plan_transform(Q, X, 2)
229+
230+
P = plan_transform(Q, X)
231+
232+
PX = P * X
233+
for k = 1:n
234+
X[:, k] = Qₙ[g,:] \ X[:, k]
235+
end
236+
for k = 1:n
237+
X[k, :] = Qₙ[g,:] \ X[k, :]
238+
end
239+
@test PX X
240+
241+
X = randn(n, n, n)
242+
P = plan_transform(Q, X)
243+
PX = P * X
244+
for k = 1:n, j = 1:n
245+
X[:, k, j] = Qₙ[g,:] \ X[:, k, j]
246+
end
247+
for k = 1:n, j = 1:n
248+
X[k, :, j] = Qₙ[g,:] \ X[k, :, j]
249+
end
250+
for k = 1:n, j = 1:n
251+
X[k, j, :] = Qₙ[g,:] \ X[k, j, :]
252+
end
253+
@test PX X
254+
end
212255
end

0 commit comments

Comments
 (0)