Skip to content

Commit d44eab0

Browse files
authored
Mapped transforms (#140)
* Mapped transforms * Move Mul here, support inv plans
1 parent 13b6d27 commit d44eab0

File tree

5 files changed

+173
-86
lines changed

5 files changed

+173
-86
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.12.2"
4-
3+
version = "0.12.3"
54

65
[deps]
6+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
88
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -19,6 +19,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1919
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2020

2121
[compat]
22+
AbstractFFTs = "1"
2223
ArrayLayouts = "0.7.7, 0.8"
2324
BandedMatrices = "0.16, 0.17"
2425
BlockArrays = "0.16"

src/ContinuumArrays.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupp
99
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex, UnknownLayout,
1010
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles, applylayout,
1111
simplifiable, _simplify, AbstractLazyLayout, PaddedLayout
12-
import LinearAlgebra: pinv, dot, norm2, ldiv!, mul!
12+
import LinearAlgebra: pinv, inv, dot, norm2, ldiv!, mul!
1313
import BandedMatrices: AbstractBandedLayout, _BandedMatrix
1414
import BlockArrays: block, blockindex, unblock, blockedrange, _BlockedUnitRange, _BlockArray
1515
import FillArrays: AbstractFill, getindex_value, SquareEye
@@ -20,6 +20,7 @@ import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclu
2020
LazyQuasiArray, LazyQuasiVector, LazyQuasiMatrix, LazyLayout, LazyQuasiArrayStyle, _factorize,
2121
AbstractQuasiFill, UnionDomain, __sum, _cumsum, __cumsum, applylayout, _equals, layout_broadcasted, PolynomialLayout
2222
import InfiniteArrays: Infinity, InfAxes
23+
import AbstractFFTs: Plan
2324

2425
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, ℵ₁, Inclusion, Basis, grid, plotgrid, affine, .., transform, expand, plan_transform, basis, coefficients
2526

@@ -91,6 +92,7 @@ checkpoints(A::AbstractQuasiMatrix) = checkpoints(axes(A,1))
9192

9293

9394
include("operators.jl")
95+
include("plans.jl")
9496
include("bases/bases.jl")
9597

9698
include("plotting.jl")

src/bases/bases.jl

Lines changed: 7 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -173,93 +173,18 @@ grid(P, n...) = _grid(MemoryLayout(P), P, n...)
173173
# values(f) =
174174

175175

176-
struct TransformFactorization{T,Grid,Plan} <: Factorization{T}
177-
grid::Grid
178-
plan::Plan
179-
end
180-
181-
TransformFactorization{T}(grid, plan) where T = TransformFactorization{T,typeof(grid),typeof(plan)}(grid, plan)
182-
183-
"""
184-
TransformFactorization(grid, plan)
185-
186-
associates a planned transform with a grid. That is, if `F` is a `TransformFactorization`, then
187-
`F \\ f` is equivalent to `F.plan * f[F.grid]`.
188-
"""
189-
TransformFactorization(grid, plan) = TransformFactorization{promote_type(eltype(eltype(grid)),eltype(plan))}(grid, plan)
190-
191-
192-
193-
grid(T::TransformFactorization) = T.grid
194-
function size(T::TransformFactorization, k)
195-
@assert k == 2 # TODO: make consistent
196-
size(T.plan,1)
197-
end
198-
199-
200-
\(a::TransformFactorization, b::AbstractQuasiVector) = a.plan * convert(Array, b[a.grid])
201-
\(a::TransformFactorization, b::AbstractQuasiMatrix) = a.plan * convert(Array, b[a.grid,:])
202-
203-
"""
204-
InvPlan(factorization, dims)
205-
206-
Takes a factorization and supports it applied to different dimensions.
207-
"""
208-
struct InvPlan{T, Fact, Dims} # <: Plan{T} We don't depend on AbstractFFTs
209-
factorization::Fact
210-
dims::Dims
211-
end
212176

213-
InvPlan(fact, dims) = InvPlan{eltype(fact), typeof(fact), typeof(dims)}(fact, dims)
214-
215-
size(F::InvPlan, k...) = size(F.factorization, k...)
216-
217-
218-
function *(P::InvPlan{<:Any,<:Any,Int}, x::AbstractVector)
219-
@assert P.dims == 1
220-
P.factorization \ x
221-
end
222-
223-
function *(P::InvPlan{<:Any,<:Any,Int}, X::AbstractMatrix)
224-
if P.dims == 1
225-
P.factorization \ X
226-
else
227-
@assert P.dims == 2
228-
permutedims(P.factorization \ permutedims(X))
229-
end
230-
end
231-
232-
function *(P::InvPlan{<:Any,<:Any,Int}, X::AbstractArray{<:Any,3})
233-
Y = similar(X)
234-
if P.dims == 1
235-
for j in axes(X,3)
236-
Y[:,:,j] = P.factorization \ X[:,:,j]
237-
end
238-
elseif P.dims == 2
239-
for k in axes(X,1)
240-
Y[k,:,:] = P.factorization \ X[k,:,:]
241-
end
242-
else
243-
@assert P.dims == 3
244-
for k in axes(X,1), j in axes(X,2)
245-
Y[k,j,:] = P.factorization \ X[k,j,:]
246-
end
247-
end
248-
Y
177+
function plan_grid_transform(lay, L, szs::NTuple{N,Int}, dims=1:N) where N
178+
p = grid(L)
179+
p, InvPlan(factorize(L[p,:]), dims)
249180
end
250181

251-
function *(P::InvPlan, X::AbstractArray)
252-
for d in P.dims
253-
X = InvPlan(P.factorization, d) * X
254-
end
255-
X
182+
function plan_grid_transform(::MappedBasisLayout, L, szs::NTuple{N,Int}, dims=1:N) where N
183+
x,F = plan_grid_transform(demap(L), szs, dims)
184+
invmap(parentindices(L)[1])[x], F
256185
end
257186

258-
259-
function plan_grid_transform(L, szs::NTuple{N,Int}, dims=1:N) where N
260-
p = grid(L)
261-
p, InvPlan(factorize(L[p,:]), dims)
262-
end
187+
plan_grid_transform(L, szs::NTuple{N,Int}, dims=1:N) where N = plan_grid_transform(MemoryLayout(L), L, szs, dims)
263188

264189
plan_grid_transform(L, arr::AbstractArray{<:Any,N}, dims=1:N) where N =
265190
plan_grid_transform(L, size(arr), dims)

src/plans.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
2+
struct TransformFactorization{T,Grid,Plan} <: Factorization{T}
3+
grid::Grid
4+
plan::Plan
5+
end
6+
7+
TransformFactorization{T}(grid, plan) where T = TransformFactorization{T,typeof(grid),typeof(plan)}(grid, plan)
8+
9+
"""
10+
TransformFactorization(grid, plan)
11+
12+
associates a planned transform with a grid. That is, if `F` is a `TransformFactorization`, then
13+
`F \\ f` is equivalent to `F.plan * f[F.grid]`.
14+
"""
15+
TransformFactorization(grid, plan) = TransformFactorization{promote_type(eltype(eltype(grid)),eltype(plan))}(grid, plan)
16+
17+
18+
19+
grid(T::TransformFactorization) = T.grid
20+
function size(T::TransformFactorization, k)
21+
@assert k == 2 # TODO: make consistent
22+
size(T.plan,1)
23+
end
24+
25+
26+
\(a::TransformFactorization, b::AbstractQuasiVector) = a.plan * convert(Array, b[a.grid])
27+
\(a::TransformFactorization, b::AbstractQuasiMatrix) = a.plan * convert(Array, b[a.grid,:])
28+
29+
"""
30+
InvPlan(factorization, dims)
31+
32+
Takes a factorization and supports it applied to different dimensions.
33+
"""
34+
struct InvPlan{T, Fact, Dims} <: Plan{T}
35+
factorization::Fact
36+
dims::Dims
37+
end
38+
39+
InvPlan(fact, dims) = InvPlan{eltype(fact), typeof(fact), typeof(dims)}(fact, dims)
40+
41+
size(F::InvPlan, k...) = size(F.factorization, k...)
42+
43+
44+
function *(P::InvPlan{<:Any,<:Any,Int}, x::AbstractVector)
45+
@assert P.dims == 1
46+
P.factorization \ x
47+
end
48+
49+
function *(P::InvPlan{<:Any,<:Any,Int}, X::AbstractMatrix)
50+
if P.dims == 1
51+
P.factorization \ X
52+
else
53+
@assert P.dims == 2
54+
permutedims(P.factorization \ permutedims(X))
55+
end
56+
end
57+
58+
function *(P::InvPlan{<:Any,<:Any,Int}, X::AbstractArray{<:Any,3})
59+
Y = similar(X)
60+
if P.dims == 1
61+
for j in axes(X,3)
62+
Y[:,:,j] = P.factorization \ X[:,:,j]
63+
end
64+
elseif P.dims == 2
65+
for k in axes(X,1)
66+
Y[k,:,:] = P.factorization \ X[k,:,:]
67+
end
68+
else
69+
@assert P.dims == 3
70+
for k in axes(X,1), j in axes(X,2)
71+
Y[k,j,:] = P.factorization \ X[k,j,:]
72+
end
73+
end
74+
Y
75+
end
76+
77+
function *(P::InvPlan, X::AbstractArray)
78+
for d in P.dims
79+
X = InvPlan(P.factorization, d) * X
80+
end
81+
X
82+
end
83+
84+
85+
"""
86+
MulPlan(matrix, dims)
87+
88+
Takes a matrix and supports it applied to different dimensions.
89+
"""
90+
struct MulPlan{T, Fact, Dims} <: Plan{T}
91+
matrix::Fact
92+
dims::Dims
93+
end
94+
95+
MulPlan(fact, dims) = MulPlan{eltype(fact), typeof(fact), typeof(dims)}(fact, dims)
96+
97+
function *(P::MulPlan{<:Any,<:Any,Int}, x::AbstractVector)
98+
@assert P.dims == 1
99+
P.matrix * x
100+
end
101+
102+
function *(P::MulPlan{<:Any,<:Any,Int}, X::AbstractMatrix)
103+
if P.dims == 1
104+
P.matrix * X
105+
else
106+
@assert P.dims == 2
107+
permutedims(P.matrix * permutedims(X))
108+
end
109+
end
110+
111+
function *(P::MulPlan{<:Any,<:Any,Int}, X::AbstractArray{<:Any,3})
112+
Y = similar(X)
113+
if P.dims == 1
114+
for j in axes(X,3)
115+
Y[:,:,j] = P.matrix * X[:,:,j]
116+
end
117+
elseif P.dims == 2
118+
for k in axes(X,1)
119+
Y[k,:,:] = P.matrix * X[k,:,:]
120+
end
121+
else
122+
@assert P.dims == 3
123+
for k in axes(X,1), j in axes(X,2)
124+
Y[k,j,:] = P.matrix * X[k,j,:]
125+
end
126+
end
127+
Y
128+
end
129+
130+
function *(P::MulPlan, X::AbstractArray)
131+
for d in P.dims
132+
X = MulPlan(P.matrix, d) * X
133+
end
134+
X
135+
end
136+
137+
*(A::AbstractMatrix, P::MulPlan) = MulPlan(A*P.matrix, P.dims)
138+
139+
inv(P::MulPlan) = InvPlan(factorize(P.matrix), P.dims)
140+
inv(P::InvPlan) = MulPlan(P.factorization, P.dims)

test/test_splines.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ContinuumArrays, LinearAlgebra, Base64, FillArrays, QuasiArrays, BandedMatrices, Test
22
using QuasiArrays: ApplyQuasiArray, ApplyStyle, MemoryLayout, mul, MulQuasiMatrix, Vec
33
import LazyArrays: MulStyle, LdivStyle, arguments, applied, apply
4-
import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout, SubBasisLayout, AdjointMappedBasisLayout, MappedBasisLayout
4+
import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout, SubBasisLayout, AdjointMappedBasisLayout, MappedBasisLayout, plan_grid_transform
55

66
@testset "Splines" begin
77
@testset "HeavisideSpline" begin
@@ -436,6 +436,25 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
436436
@test L[y,:] \ (y .+ y) L[y,:] \ (2y)
437437
@test L[y,:] \ (y .- y) zeros(10)
438438
end
439+
440+
@testset "transform" begin
441+
x = Inclusion(0..1)
442+
y = 2x .- 1
443+
L = LinearSpline(range(-1,stop=1,length=10))
444+
g,P = plan_grid_transform(L[y,:], (10,))
445+
X = cos.(g)
446+
@test L[y,:][g,:] * (P * X) X
447+
@test P \ (P * X) P * (P \ X) X
448+
449+
g,P = plan_grid_transform(L[y,:], (10,10))
450+
X = cos.(g .+ g')
451+
@test L[y,:][g,:]*(P * X)*L[y,:][g,:]' X
452+
@test P \ (P * X) P * (P \ X) X
453+
454+
g,P = plan_grid_transform(L[y,:], (10,10,10))
455+
X = randn(10,10,10)
456+
@test P \ (P * X) P * (P \ X) X
457+
end
439458
end
440459

441460
@testset "diff" begin

0 commit comments

Comments
 (0)