Skip to content

Commit ff83d09

Browse files
committed
Add MulQuasiArray
1 parent ac247f2 commit ff83d09

File tree

6 files changed

+151
-23
lines changed

6 files changed

+151
-23
lines changed

src/ContinuumArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import BandedMatrices: AbstractBandedLayout, _BandedMatrix
1010
include("QuasiArrays/QuasiArrays.jl")
1111
using .QuasiArrays
1212
import .QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, slice, QSlice, SubQuasiArray,
13-
QuasiDiagonal
13+
QuasiDiagonal, MulQuasiArray, MulQuasiMatrix, MulQuasiVector, QuasiMatMulMat
1414

1515
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, JacobiWeight, Jacobi, Legendre
1616

src/QuasiArrays/QuasiArrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Base: @_inline_meta, DimOrInd, OneTo, @_propagate_inbounds_meta, @_noinli
99
index_shape, to_shape, unsafe_length, @nloops, @ncall, Slice, unalias
1010
import Base: ViewIndex, Slice, ScalarIndex, RangeIndex, view, viewindexing, ensure_indexable, index_dimsum,
1111
check_parent_index_match, reindex, _isdisjoint, unsafe_indices,
12-
parentindices, reverse
12+
parentindices, reverse, ndims
1313
import Base: *, /, \, +, -, inv
1414
import Base: exp, log, sqrt,
1515
cos, sin, tan, csc, sec, cot,
@@ -23,7 +23,8 @@ import Base.Broadcast: materialize
2323
import LinearAlgebra: transpose, adjoint, checkeltype_adjoint, checkeltype_transpose, Diagonal,
2424
AbstractTriangular
2525

26-
import LazyArrays: MemoryLayout, UnknownLayout, Mul2
26+
import LazyArrays: MemoryLayout, UnknownLayout, Mul2, _materialize, MulLayout, , rmaterialize,
27+
_rmaterialize, _lmaterialize, flatten, _flatten
2728

2829
export AbstractQuasiArray, AbstractQuasiMatrix, AbstractQuasiVector, materialize
2930

src/QuasiArrays/matmul.jl

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
const QuasiArrayMulArray{styleA, styleB, p, q, T, V} =
33
Mul2{styleA, styleB, <:AbstractQuasiArray{T,p}, <:AbstractArray{V,q}}
44

5+
const ArrayMulQuasiArray{styleA, styleB, p, q, T, V} =
6+
Mul2{styleA, styleB, <:AbstractArray{T,p}, <:AbstractQuasiArray{V,q}}
7+
58
const QuasiArrayMulQuasiArray{styleA, styleB, p, q, T, V} =
69
Mul2{styleA, styleB, <:AbstractQuasiArray{T,p}, <:AbstractQuasiArray{V,q}}
710
####
@@ -13,7 +16,7 @@ const QuasiMatMulVec{styleA, styleB, T, V} = QuasiArrayMulArray{styleA, styleB,
1316
function getindex(M::QuasiMatMulVec, k::Real)
1417
A,B = M.factors
1518
ret = zero(eltype(M))
16-
@inbounds for j = 1:size(A,2)
19+
@inbounds for j in axes(A,2)
1720
ret += A[k,j] * B[j]
1821
end
1922
ret
@@ -22,7 +25,7 @@ end
2225
function getindex(M::QuasiMatMulVec, k::AbstractArray)
2326
A,B = M.factors
2427
ret = zeros(eltype(M),length(k))
25-
@inbounds for j = 1:size(A,2)
28+
@inbounds for j in axes(A,2)
2629
ret .+= view(A,k,j) .* B[j]
2730
end
2831
ret
@@ -41,3 +44,114 @@ inv(A::AbstractQuasiArray) = materialize(Inv(A))
4144

4245
*(A::AbstractQuasiArray, B::Mul) = materialize(Mul(A, B.factors...))
4346
*(A::Mul, B::AbstractQuasiArray) = materialize(Mul(A.factors..., B))
47+
48+
49+
####
50+
# MulQuasiArray
51+
#####
52+
53+
struct MulQuasiArray{T, N, MUL<:Mul} <: AbstractQuasiArray{T,N}
54+
mul::MUL
55+
end
56+
57+
const MulQuasiVector{T, MUL<:Mul} = MulQuasiArray{T, 1, MUL}
58+
const MulQuasiMatrix{T, MUL<:Mul} = MulQuasiArray{T, 2, MUL}
59+
60+
const Vec = MulQuasiVector
61+
62+
63+
MulQuasiArray{T,N}(M::MUL) where {T,N,MUL<:Mul} = MulQuasiArray{T,N,MUL}(M)
64+
MulQuasiArray{T}(M::Mul) where {T} = MulQuasiArray{T,ndims(M)}(M)
65+
MulQuasiArray(M::Mul) = MulQuasiArray{eltype(M)}(M)
66+
MulQuasiVector(M::Mul) = MulQuasiVector{eltype(M)}(M)
67+
MulQuasiMatrix(M::Mul) = MulQuasiMatrix{eltype(M)}(M)
68+
69+
MulQuasiArray(factors...) = MulQuasiArray(Mul(factors...))
70+
MulQuasiArray{T}(factors...) where T = MulQuasiArray{T}(Mul(factors...))
71+
MulQuasiArray{T,N}(factors...) where {T,N} = MulQuasiArray{T,N}(Mul(factors...))
72+
MulQuasiVector(factors...) = MulQuasiVector(Mul(factors...))
73+
MulQuasiMatrix(factors...) = MulQuasiMatrix(Mul(factors...))
74+
75+
_MulArray(factors...) = MulQuasiArray(factors...)
76+
_MulArray(factors::AbstractArray...) = MulArray(factors...)
77+
78+
axes(A::MulQuasiArray) = axes(A.mul)
79+
size(A::MulQuasiArray) = map(length, axes(A))
80+
81+
IndexStyle(::MulQuasiArray{<:Any,1}) = IndexLinear()
82+
83+
==(A::MulQuasiArray, B::MulQuasiArray) = A.mul == B.mul
84+
85+
@propagate_inbounds getindex(A::MulQuasiArray, kj::Real...) = A.mul[kj...]
86+
87+
*(A::MulQuasiArray, B::MulQuasiArray) = A.mul * B.mul
88+
*(A::MulQuasiArray, B::Mul) = A.mul * B
89+
*(A::Mul, B::MulQuasiArray) = A * B.mul
90+
*(A::MulQuasiArray, B::AbstractQuasiArray) = A.mul * B
91+
*(A::AbstractQuasiArray, B::MulQuasiArray) = A * B.mul
92+
*(A::MulQuasiArray, B::AbstractArray) = A.mul * B
93+
*(A::AbstractArray, B::MulQuasiArray) = A * B.mul
94+
95+
adjoint(A::MulQuasiArray) = MulQuasiArray(reverse(adjoint.(A.mul.factors))...)
96+
transpose(A::MulQuasiArray) = MulQuasiArray(reverse(transpose.(A.mul.factors))...)
97+
98+
99+
MemoryLayout(M::MulQuasiArray) = MulLayout(MemoryLayout.(M.mul.factors))
100+
101+
102+
103+
####
104+
# Matrix * Array
105+
####
106+
107+
_flatten(A::MulQuasiArray, B...) = _flatten(A.mul.factors..., B...)
108+
flatten(A::MulQuasiArray) = MulQuasiArray(Mul(_flatten(A.mul.factors...)))
109+
110+
111+
# the default is always Array
112+
113+
_materialize(M::QuasiArrayMulArray, _) = MulQuasiArray(M)
114+
_materialize(M::ArrayMulQuasiArray, _) = MulQuasiArray(M)
115+
_materialize(M::QuasiArrayMulQuasiArray, _) = MulQuasiArray(M)
116+
117+
118+
119+
# if multiplying two MulQuasiArrays simplifies the arguments, we materialize,
120+
# otherwise we leave it as a lazy object
121+
_mulquasi_join(As, M::MulQuasiArray, Cs) = MulQuasiArray(As..., M.mul.factors..., Cs...)
122+
_mulquasi_join(As, B, Cs) = *(As..., B, Cs...)
123+
124+
125+
function _materialize(M::Mul2{<:Any,<:Any,<:MulQuasiArray,<:MulQuasiArray}, _)
126+
As, Bs = M.factors
127+
_mul_join(reverse(tail(reverse(As))), last(As) * first(Bs), tail(Bs))
128+
end
129+
130+
131+
function _materialize(M::Mul2{<:Any,<:Any,<:MulQuasiArray,<:AbstractQuasiArray}, _)
132+
As, B = M.factors
133+
(As.mul.factors..., B)
134+
end
135+
136+
function _materialize(M::Mul2{<:Any,<:Any,<:AbstractQuasiArray,<:MulQuasiArray}, _)
137+
A, Bs = M.factors
138+
*(A, Bs.mul.factors...)
139+
end
140+
141+
# A MulQuasiArray can't be materialized further left-to-right, so we do right-to-left
142+
function _materialize(M::Mul2{<:Any,<:Any,<:MulQuasiArray,<:AbstractArray}, _)
143+
As, B = M.factors
144+
(As.mul.factors..., B)
145+
end
146+
147+
function _lmaterialize(A::MulQuasiArray, B, C...)
148+
As = A.mul.factors
149+
flatten(_MulArray(reverse(tail(reverse(As)))..., _lmaterialize(last(As), B, C...)))
150+
end
151+
152+
153+
154+
function _rmaterialize(Z::MulQuasiArray, Y, W...)
155+
Zs = Z.mul.factors
156+
flatten(_MulArray(_rmaterialize(first(Zs), Y, W...), tail(Zs)...))
157+
end

src/bases/jacobi.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
function materialize(M::Mul2{<:Any,<:Any,<:Derivative{<:Any,<:ChebyshevInterval},<:Jacobi})
4343
D, S = M.factors
4444
A = PInv(Jacobi(S.b+1,S.a+1))*D*S
45-
Mul(Jacobi(S.b+1,S.a+1), A)
45+
MulQuasiMatrix(Jacobi(S.b+1,S.a+1), A)
4646
end
4747

4848
# pinv(Legendre())D*W*Jacobi(true,true)
@@ -62,7 +62,7 @@ function materialize(M::Mul{<:Tuple,<:Tuple{<:Derivative{<:Any,<:ChebyshevInterv
6262
w = parent(W)
6363
(w.a && S.a && w.b && S.b) || throw(ArgumentError())
6464
A = pinv(Legendre{eltype(M)}())*D*W*S
65-
Mul(Legendre(), A)
65+
MulQuasiMatrix(Legendre(), A)
6666
end
6767

6868
function materialize(M::Mul{<:Tuple,<:Tuple{<:PInv{<:Any,<:Jacobi{Bool}},

src/bases/splines.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ function materialize(M::Mul2{<:Any,<:Any,<:QuasiAdjoint{<:Any,<:HeavisideSpline}
7878
end
7979

8080
## Derivative
81-
function copyto!(dest::Mul2{<:Any,<:Any,<:HeavisideSpline},
81+
function copyto!(dest::MulQuasiMatrix{<:Any,<:Mul2{<:Any,<:Any,<:HeavisideSpline}},
8282
M::Mul2{<:Any,<:Any,<:Derivative,<:LinearSpline})
8383
D, L = M.factors
84-
H, A = dest.factors
84+
H, A = dest.mul.factors
8585
x = H.points
8686

8787
axes(dest) == axes(M) || throw(DimensionMismatch("axes must be same"))
@@ -98,7 +98,7 @@ end
9898
function similar(M::Mul2{<:Any,<:Any,<:Derivative,<:LinearSpline}, ::Type{T}) where T
9999
D, B = M.factors
100100
n = size(B,2)
101-
Mul(HeavisideSpline{T}(B.points),
101+
MulQuasiMatrix(HeavisideSpline{T}(B.points),
102102
BandedMatrix{T}(undef, (n-1,n), (0,1)))
103103
end
104104

test/runtests.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, Test,
22
InfiniteArrays
33
import ContinuumArrays: ℵ₁, materialize
4-
import ContinuumArrays.QuasiArrays: SubQuasiArray
4+
import ContinuumArrays.QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec
5+
import LazyArrays: rmaterialize,
56

67
@testset "DiracDelta" begin
78
δ = DiracDelta(-1..3)
@@ -75,7 +76,16 @@ end
7576
L = LinearSpline([1,2,3])
7677
f = L*[1,2,4]
7778
D = Derivative(axes(L,1))
79+
@test D*L isa MulQuasiMatrix
80+
81+
fp = (D*L)*[1,2,4]
82+
@test fp isa Vec
83+
@test length(fp.mul.factors) == 2
84+
@test fp[1.1] 1
85+
@test fp[2.2] 2
86+
7887
fp = D*f
88+
@test length(fp.mul.factors) == 2
7989

8090
@test fp[1.1] 1
8191
@test fp[2.2] 2
@@ -86,17 +96,19 @@ end
8696
L = LinearSpline(0:2)
8797

8898
D = Derivative(axes(L,1))
89-
M = materialize(Mul(D',D,L))
90-
DL = D*L
91-
@test M.factors == tuple(D', (D*L).factors...)
99+
M = rmaterialize(Mul(D',D*L))
100+
@test length(M.mul.factors) == 3
101+
@test last(M.mul.factors) isa BandedMatrix
92102

93-
@test materialize(Mul(L', D', D, L)) == (L'D'*D*L) ==
94-
[1.0 -1 0; -1.0 2.0 -1.0; 0.0 -1.0 1.0]
103+
@test M.mul.factors == rmaterialize(Mul(D',D,L)).mul.factors ==
104+
(D',D,L).mul.factors == *(D',D,L).mul.factors
95105

96-
@test materialize(Mul(L', D', D, L)) isa BandedMatrix
97-
@test (L'D'*D*L) isa BandedMatrix
106+
@test (L'D') isa MulQuasiMatrix
107+
A = (L'D') * (D*L)
108+
@test A == (D*L)'*(D*L) == [1.0 -1 0; -1.0 2.0 -1.0; 0.0 -1.0 1.0]
98109

99-
@test bandwidths(materialize(L'D'*D*L)) == (1,1)
110+
@test A isa MulArray
111+
@test bandwidths(A) == (1,1)
100112
end
101113

102114
@testset "Views" begin
@@ -121,7 +133,7 @@ end
121133

122134
@testset "Subindex of splines" begin
123135
L = LinearSpline(range(0,stop=1,length=10))
124-
@test L[:,2:end-1] isa Mul
136+
@test L[:,2:end-1] isa MulQuasiMatrix
125137
@test_broken L[:,2:end-1][0.1,1] == L[0.1,2]
126138
v = randn(8)
127139
f = L[:,2:end-1] * v
@@ -132,7 +144,7 @@ end
132144
L = LinearSpline(range(0,stop=1,length=10))
133145
B = L[:,2:end-1] # Zero dirichlet by dropping first and last spline
134146
D = Derivative(axes(L,1))
135-
Δ = -(B'D'D*B) # Weak Laplacian
147+
Δ = -((B'D')*(D*B)) # Weak Laplacian
136148

137149
f = L*exp.(L.points) # project exp(x)
138150
u = B *\ (B'f))
@@ -145,7 +157,8 @@ end
145157
L = LinearSpline(range(0,stop=1,length=10))
146158
B = L[:,2:end-1] # Zero dirichlet by dropping first and last spline
147159
D = Derivative(axes(L,1))
148-
A = -(B'D'D*B) + 100^2*B'B # Weak Laplacian
160+
161+
A = -((B'D')*(D*B)) + 100^2*B'B # Weak Laplacian
149162

150163
f = L*exp.(L.points) # project exp(x)
151164
u = B * (A \ (B'f))
@@ -159,7 +172,7 @@ end
159172
D = Derivative(axes(W,1))
160173
P = Legendre()
161174

162-
A = @inferred(PInv(Jacobi(2,2))*D*S)
175+
A = @inferred(PInv(Jacobi(2,2))*(D*S))
163176
@test A isa BandedMatrix
164177
@test size(A) == (∞,∞)
165178
@test A[1:10,1:10] == diagm(1 => 1:0.5:5)

0 commit comments

Comments
 (0)