Skip to content

Commit e647465

Browse files
committed
auto-fullmaterialize
1 parent 171c091 commit e647465

File tree

4 files changed

+65
-64
lines changed

4 files changed

+65
-64
lines changed

src/ContinuumArrays.jl

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -53,50 +53,6 @@ end
5353

5454

5555

56-
57-
most(a) = reverse(tail(reverse(a)))
58-
59-
MulQuasiOrArray = Union{MulArray,MulQuasiArray}
60-
61-
_factors(M::MulQuasiOrArray) = M.applied.args
62-
_factors(M) = (M,)
63-
64-
_flatten() = ()
65-
_flatten(A, B...) = (A, _flatten(B...)...)
66-
_flatten(A::Mul, B...) = _flatten(A.args..., B...)
67-
flatten(A::Mul) = Mul(_flatten(A.args...)...)
68-
69-
_flatten(A::MulQuasiArray, B...) = _flatten(A.applied, B...)
70-
flatten(A::MulQuasiArray) = MulQuasiArray(flatten(A.applied))
71-
72-
function fullmaterialize(M::Applied{<:Any,typeof(*)})
73-
M_mat = materialize(flatten(M))
74-
typeof(M_mat) <: MulQuasiOrArray || return M_mat
75-
typeof(M_mat.applied) == typeof(M) || return(fullmaterialize(M_mat))
76-
77-
ABC = M_mat.applied.args
78-
length(ABC) 2 && return M_mat
79-
80-
AB = most(ABC)
81-
Mhead = fullmaterialize(Mul(AB...))
82-
83-
typeof(_factors(Mhead)) == typeof(AB) ||
84-
return fullmaterialize(Mul(_factors(Mhead)..., last(ABC)))
85-
86-
BC = tail(ABC)
87-
Mtail = fullmaterialize(Mul(BC...))
88-
typeof(_factors(Mtail)) == typeof(BC) ||
89-
return fullmaterialize(Mul(first(ABC), _factors(Mtail)...))
90-
91-
first(ABC) * Mtail
92-
end
93-
94-
fullmaterialize(M::ApplyQuasiArray) = fullmaterialize(M.applied)
95-
fullmaterialize(M) = M
96-
97-
materialize(M::Applied{<:Any,typeof(*),<:Tuple{Vararg{<:Union{Adjoint,QuasiAdjoint,QuasiDiagonal}}}}) =
98-
materialize(Mul(reverse(adjoint.(M.args))...))'
99-
10056
include("operators.jl")
10157
include("bases/bases.jl")
10258

src/QuasiArrays/QuasiArrays.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import Base.Broadcast: materialize
2323
import LinearAlgebra: transpose, adjoint, checkeltype_adjoint, checkeltype_transpose, Diagonal,
2424
AbstractTriangular, pinv, inv
2525

26-
import LazyArrays: MemoryLayout, UnknownLayout, Mul2, _materialize, MulLayout, ,
26+
import LazyArrays: MemoryLayout, UnknownLayout, Mul2, _materialize, MulLayout, ,
2727
_lmaterialize, InvOrPInv, ApplyStyle,
2828
LayoutApplyStyle, Applied
2929

@@ -50,4 +50,8 @@ include("abstractquasiarraymath.jl")
5050
include("quasiadjtrans.jl")
5151
include("quasidiagonal.jl")
5252

53+
54+
materialize(M::Applied{<:Any,typeof(*),<:Tuple{Vararg{<:Union{Adjoint,QuasiAdjoint,QuasiDiagonal}}}}) =
55+
materialize(Mul(reverse(adjoint.(M.args))...))'
56+
5357
end

src/QuasiArrays/matmul.jl

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ struct LazyQuasiArrayApplyStyle <: ApplyStyle end
6363
ndims(M::Applied{LazyQuasiArrayApplyStyle,typeof(*)}) = ndims(last(M.args))
6464

6565

66-
*(A::AbstractQuasiArray, B...) = materialize(Mul(A,B...))
67-
*(A::AbstractQuasiArray, B::AbstractQuasiArray, C...) = materialize(Mul(A,B,C...))
68-
*(A::AbstractArray, B::AbstractQuasiArray, C...) = materialize(Mul(A,B,C...))
66+
*(A::AbstractQuasiArray, B...) = fullmaterialize(materialize(Mul(A,B...)))
67+
*(A::AbstractQuasiArray, B::AbstractQuasiArray, C...) = fullmaterialize(materialize(Mul(A,B,C...)))
68+
*(A::AbstractArray, B::AbstractQuasiArray, C...) = fullmaterialize(materialize(Mul(A,B,C...)))
6969

7070
pinv(A::AbstractQuasiArray) = materialize(PInv(A))
7171
inv(A::AbstractQuasiArray) = materialize(Inv(A))
@@ -74,8 +74,8 @@ inv(A::AbstractQuasiArray) = materialize(Inv(A))
7474
\(A::AbstractQuasiArray, B::AbstractQuasiArray) = materialize(Ldiv(A,B))
7575

7676

77-
*(A::AbstractQuasiArray, B::Mul, C...) = materialize(Mul(A, B.args..., C...))
78-
*(A::Mul, B::AbstractQuasiArray, C...) = materialize(Mul(A.args..., B, C...))
77+
*(A::AbstractQuasiArray, B::Mul, C...) = fullmaterialize(materialize(Mul(A, B.args..., C...)))
78+
*(A::Mul, B::AbstractQuasiArray, C...) = fullmaterialize(materialize(Mul(A.args..., B, C...)))
7979

8080

8181
struct ApplyQuasiArray{T, N, App<:Applied} <: AbstractQuasiArray{T,N}
@@ -142,11 +142,53 @@ MulQuasiMatrix(factors...) = MulQuasiMatrix(Mul(factors...))
142142
_MulArray(factors...) = MulQuasiArray(factors...)
143143
_MulArray(factors::AbstractArray...) = MulArray(factors...)
144144

145-
*(A::MulQuasiArray, B::MulQuasiArray) = materialize(Mul(A.applied.args..., B.applied.args...))
146-
*(A::MulQuasiArray, B::AbstractQuasiArray) = materialize(Mul(A.applied.args..., B))
147-
*(A::AbstractQuasiArray, B::MulQuasiArray) = materialize(Mul(A, B.applied.args...))
148-
*(A::MulQuasiArray, B::AbstractArray) = materialize(Mul(A.applied.args..., B))
149-
*(A::AbstractArray, B::MulQuasiArray) = materialize(Mul(A, B.applied.args...))
145+
most(a) = reverse(tail(reverse(a)))
146+
147+
MulQuasiOrArray = Union{MulArray,MulQuasiArray}
148+
149+
_factors(M::MulQuasiOrArray) = M.applied.args
150+
_factors(M) = (M,)
151+
152+
_flatten() = ()
153+
_flatten(A, B...) = (A, _flatten(B...)...)
154+
_flatten(A::Mul, B...) = _flatten(A.args..., B...)
155+
flatten(A::Mul) = Mul(_flatten(A.args...)...)
156+
157+
_flatten(A::MulQuasiArray, B...) = _flatten(A.applied, B...)
158+
flatten(A::MulQuasiArray) = MulQuasiArray(flatten(A.applied))
159+
160+
function fullmaterialize(M::Applied{<:Any,typeof(*)})
161+
M_mat = materialize(flatten(M))
162+
typeof(M_mat) <: MulQuasiOrArray || return M_mat
163+
typeof(M_mat.applied) == typeof(M) || return(fullmaterialize(M_mat))
164+
165+
ABC = M_mat.applied.args
166+
length(ABC) 2 && return M_mat
167+
168+
AB = most(ABC)
169+
Mhead = fullmaterialize(Mul(AB...))
170+
171+
typeof(_factors(Mhead)) == typeof(AB) ||
172+
return fullmaterialize(Mul(_factors(Mhead)..., last(ABC)))
173+
174+
BC = tail(ABC)
175+
Mtail = fullmaterialize(Mul(BC...))
176+
typeof(_factors(Mtail)) == typeof(BC) ||
177+
return fullmaterialize(Mul(first(ABC), _factors(Mtail)...))
178+
179+
first(ABC) * Mtail
180+
end
181+
182+
fullmaterialize(M::ApplyQuasiArray) = fullmaterialize(M.applied)
183+
fullmaterialize(M) = M
184+
185+
*(A::MulQuasiArray, B::MulQuasiArray) = fullmaterialize(materialize(Mul(A.applied.args..., B.applied.args...)))
186+
*(A::MulQuasiArray, B::AbstractQuasiArray) = fullmaterialize(materialize(Mul(A.applied.args..., B)))
187+
*(A::AbstractQuasiArray, B::MulQuasiArray) = fullmaterialize(materialize(Mul(A, B.applied.args...)))
188+
*(A::MulQuasiArray, B::AbstractArray) = fullmaterialize(materialize(Mul(A.applied.args..., B)))
189+
*(A::AbstractArray, B::MulQuasiArray) = fullmaterialize(materialize(Mul(A, B.applied.args...)))
190+
191+
150192

151193
adjoint(A::MulQuasiArray) = MulQuasiArray(reverse(adjoint.(A.applied.args))...)
152194
transpose(A::MulQuasiArray) = MulQuasiArray(reverse(transpose.(A.applied.args))...)

test/runtests.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, Test,
2-
InfiniteArrays
1+
using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, Test, InfiniteArrays
32
import ContinuumArrays: ℵ₁, materialize
43
import ContinuumArrays.QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle
54
import LazyArrays: MemoryLayout, ApplyStyle
@@ -96,7 +95,7 @@ end
9695
@test M.style isa LazyQuasiArrayApplyStyle
9796
@test eltype(materialize(M)) == Float64
9897

99-
fp = fullmaterialize(D*L*[1,2,4])
98+
fp = D*L*[1,2,4]
10099

101100
@test eltype(fp) == Float64
102101

@@ -106,7 +105,7 @@ end
106105
@test fp[2.2] 2
107106

108107

109-
fp = fullmaterialize(D*f)
108+
fp = D*f
110109
@test length(fp.applied.args) == 2
111110
@test fp[1.1] 1
112111
@test fp[2.2] 2
@@ -118,13 +117,13 @@ end
118117

119118
D = Derivative(axes(L,1))
120119

121-
M = fullmaterialize(ContinuumArrays.flatten(Mul(D',D*L)))
122-
@test length(M.applied.args) == 3
123-
@test last(M.applied.args) isa BandedMatrix
120+
M = ContinuumArrays.QuasiArrays.flatten(Mul(D',D*L))
121+
@test length(M.args) == 3
122+
@test last(M.args) isa BandedMatrix
124123

125124
@test (L'D') isa MulQuasiMatrix
126125
A = (L'D') * (D*L)
127-
@test fullmaterialize(A) == fullmaterialize((D*L)'*(D*L)) == [1.0 -1 0; -1.0 2.0 -1.0; 0.0 -1.0 1.0]
126+
@test A == (D*L)'*(D*L) == [1.0 -1 0; -1.0 2.0 -1.0; 0.0 -1.0 1.0]
128127
@test_skip bandwidths(A) == (1,1)
129128
end
130129

@@ -161,7 +160,7 @@ end
161160
L = LinearSpline(range(0,stop=1,length=10))
162161
B = L[:,2:end-1] # Zero dirichlet by dropping first and last spline
163162
D = Derivative(axes(L,1))
164-
Δ = -fullmaterialize((B'D')*(D*B)) # Weak Laplacian
163+
Δ = -((B'D')*(D*B)) # Weak Laplacian
165164

166165
f = L*exp.(L.points) # project exp(x)
167166
u = B *\ (B'f))

0 commit comments

Comments
 (0)