Skip to content

Commit 3390570

Browse files
authored
Remove transform_ldiv_if_columns to match the *_size pattern (#161)
* Remove transform_ldiv_if_columns to match the *_size pattern * Update bases.jl * Simplify basis_broadcast and add tocoefficients default * v0.16 * add simplify for ldiv * Update Project.toml
1 parent 1ef9856 commit 3390570

File tree

4 files changed

+71
-50
lines changed

4 files changed

+71
-50
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.15.2"
3+
version = "0.16"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -36,7 +36,7 @@ FillArrays = "1.0"
3636
InfiniteArrays = "0.12, 0.13"
3737
Infinities = "0.1"
3838
IntervalSets = "0.7"
39-
LazyArrays = "1.6.1"
39+
LazyArrays = "1.7"
4040
Makie = "0.19"
4141
QuasiArrays = "0.11.1"
4242
RecipesBase = "1.0"

src/ContinuumArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import LinearAlgebra: pinv, inv, dot, norm2, ldiv!, mul!
1313
import BandedMatrices: AbstractBandedLayout, _BandedMatrix
1414
import BlockArrays: block, blockindex, unblock, blockedrange, _BlockedUnitRange, _BlockArray
1515
import FillArrays: AbstractFill, getindex_value, SquareEye
16-
import ArrayLayouts: mul, ZerosLayout, ScalarLayout, AbstractStridedLayout, check_mul_axes
16+
import ArrayLayouts: mul, ldiv, ZerosLayout, ScalarLayout, AbstractStridedLayout, check_mul_axes, check_ldiv_axes
1717
import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclusion, SubQuasiArray,
1818
QuasiDiagonal, MulQuasiArray, MulQuasiMatrix, MulQuasiVector, QuasiMatMulMat, QuasiArrayLayout,
1919
ApplyQuasiArray, ApplyQuasiMatrix, LazyQuasiArrayApplyStyle, AbstractQuasiArrayApplyStyle, AbstractQuasiLazyLayout,

src/bases/bases.jl

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,33 @@ equals_layout(::SubBasisLayouts, ::SubBasisLayouts, A::SubQuasiArray, B::SubQuas
5858
equals_layout(::MappedBasisLayouts, ::MappedBasisLayouts, A::SubQuasiArray, B::SubQuasiArray) = parentindices(A) == parentindices(B) && demap(A) == demap(B)
5959
equals_layout(::AbstractWeightedBasisLayout, ::AbstractWeightedBasisLayout, A, B) = weight(A) == weight(B) && unweighted(A) == unweighted(B)
6060

61-
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)}}) = +(broadcast(\,Ref(L.A),arguments(L.B))...)
62-
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)},<:Any,<:AbstractQuasiVector}) =
63-
transform_ldiv(L.A, L.B)
6461
for op in (:+, :-)
65-
@eval @inline copy(L::Ldiv{Lay,BroadcastLayout{typeof($op)},<:Any,<:AbstractQuasiVector}) where Lay<:MappedBasisLayouts =
66-
copy(Ldiv{Lay,LazyLayout}(L.A,L.B))
62+
@eval begin
63+
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof($op)}}) = basis_broadcast_ldiv_size($op, size(L), L.A, L.B)
64+
@inline copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof($op)}}) = copy(Ldiv{BasisLayout,BroadcastLayout{typeof($op)}}(L.A, L.B))
65+
basis_broadcast_ldiv_size(::typeof($op), ::Tuple{Integer}, A, B) = transform_ldiv(A, B)
66+
end
6767
end
6868

69-
@inline function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)}})
70-
a,b = arguments(L.B)
71-
(L.A\a)-(L.A\b)
69+
basis_broadcast_ldiv_size(::typeof(+), _, A, B) = +(broadcast(\,Ref(A),arguments(B))...)
70+
71+
72+
73+
@inline function basis_broadcast_ldiv_size(::typeof(-), _, A, B)
74+
a,b = arguments(B)
75+
(A\a)-(A\b)
7276
end
7377

74-
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)},<:Any,<:AbstractQuasiVector}) =
75-
transform_ldiv(L.A, L.B)
7678

79+
# TODO: remove as Not type stable
80+
simplifiable(L::Ldiv{<:AbstractBasisLayout,<:AbstractBasisLayout}) = Val(L.A == L.B)
7781
@inline function copy(P::Ldiv{<:AbstractBasisLayout,<:AbstractBasisLayout})
7882
A, B = P.A, P.B
7983
A == B || throw(ArgumentError("Override copy for $(typeof(A)) \\ $(typeof(B))"))
8084
SquareEye{eltype(eltype(P))}((axes(A,2),)) # use double eltype for array-valued
8185
end
86+
87+
simplifiable(L::Ldiv{<:SubBasisLayouts,<:SubBasisLayouts}) = Val(parent(L.A) == parent(L.B))
8288
@inline function copy(P::Ldiv{<:SubBasisLayouts,<:SubBasisLayouts})
8389
A, B = P.A, P.B
8490
parent(A) == parent(B) ||
@@ -91,20 +97,18 @@ end
9197
demap(A)\demap(B)
9298
end
9399

94-
function transform_ldiv_if_columns(P::Ldiv{<:MappedBasisLayouts,<:Any,<:Any,<:AbstractQuasiVector}, ::OneTo)
95-
A,B = P.A, P.B
96-
demap(A) \ B[invmap(basismap(A))]
97-
end
98-
99-
function transform_ldiv_if_columns(P::Ldiv{<:MappedBasisLayouts,<:Any,<:Any,<:AbstractQuasiMatrix}, ::OneTo)
100-
A,B = P.A, P.B
101-
demap(A) \ B[invmap(basismap(A)),:]
102-
end
100+
copy(P::Ldiv{<:MappedBasisLayouts}) = mapped_ldiv_size(size(P), P.A, P.B)
101+
copy(P::Ldiv{<:MappedBasisLayouts, <:AbstractLazyLayout}) = mapped_ldiv_size(size(P), P.A, P.B)
102+
copy(P::Ldiv{<:MappedBasisLayouts, <:AbstractBasisLayout}) = mapped_ldiv_size(size(P), P.A, P.B)
103+
@inline copy(L::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(hcat)}}) = mapped_ldiv_size(size(L), L.A, L.B)
104+
copy(P::Ldiv{<:MappedBasisLayouts, ApplyLayout{typeof(*)}}) = copy(Ldiv{BasisLayout,ApplyLayout{typeof(*)}}(P.A, P.B))
103105

104-
copy(L::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)}}) = copy(Ldiv{UnknownLayout,ApplyLayout{typeof(*)}}(L.A,L.B))
105-
copy(L::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = transform_ldiv(L.A, L.B)
106+
mapped_ldiv_size(::Tuple{Integer}, A, B) = demap(A) \ B[invmap(basismap(A))]
107+
mapped_ldiv_size(::Tuple{Integer,Int}, A, B) = demap(A) \ B[invmap(basismap(A)),:]
108+
mapped_ldiv_size(::Tuple{Integer,Any}, A, B) = copy(Ldiv{BasisLayout,typeof(MemoryLayout(B))}(A, B))
106109

107-
@inline copy(L::Ldiv{<:AbstractBasisLayout,<:SubBasisLayouts}) = apply(\, L.A, ApplyQuasiArray(L.B))
110+
# following allows us to use simplification
111+
@inline copy(L::Ldiv{Lay,<:SubBasisLayouts}) where Lay<:AbstractBasisLayout = copy(Ldiv{Lay,ApplyLayout{typeof(*)}}(L.A, L.B))
108112
@inline function copy(L::Ldiv{<:SubBasisLayouts,<:AbstractBasisLayout})
109113
P = parent(L.A)
110114
kr, jr = parentindices(L.A)
@@ -146,11 +150,7 @@ _broadcast_mul_ldiv(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) =
146150
_broadcast_mul_ldiv(_, A, B) = copy(Ldiv{typeof(MemoryLayout(A)),UnknownLayout}(A,B))
147151

148152
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
149-
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
150-
151-
# ambiguity
152153
copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
153-
copy(L::Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
154154

155155

156156
# expansion
@@ -257,8 +257,8 @@ end
257257
plan_ldiv(A, B::AbstractQuasiVector) = factorize(A)
258258
plan_ldiv(A, B::AbstractQuasiMatrix) = factorize(A, size(B,2))
259259

260-
transform_ldiv(A::AbstractQuasiArray{T}, B::AbstractQuasiArray{V}, _) where {T,V} = plan_ldiv(A, B) \ B
261-
transform_ldiv(A, B) = transform_ldiv(A, B, size(A))
260+
transform_ldiv_size(_, A::AbstractQuasiArray{T}, B::AbstractQuasiArray{V}) where {T,V} = plan_ldiv(A, B) \ B
261+
transform_ldiv(A, B) = transform_ldiv_size(size(A), A, B)
262262

263263

264264
"""
@@ -291,28 +291,29 @@ in that basis.
291291
"""
292292
function expand(v)
293293
P = basis(v)
294-
ApplyQuasiArray(*, P, P \ v)
294+
ApplyQuasiArray(*, P, tocoefficients(P \ v))
295295
end
296296

297297

298298

299-
copy(L::Ldiv{<:AbstractBasisLayout}) = transform_ldiv(L.A, L.B)
300-
# TODO: redesign to use simplifiable(\, A, B)
301-
copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = transform_ldiv(L.A, L.B)
302-
copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)}}) = copy(Ldiv{UnknownLayout,ApplyLayout{typeof(*)}}(L.A, L.B))
303-
# A BroadcastLayout of unknown function is only knowable pointwise
304-
transform_ldiv_if_columns(L, _) = ApplyQuasiArray(\, L.A, L.B)
305-
transform_ldiv_if_columns(L, ::OneTo) = transform_ldiv(L.A,L.B)
306-
transform_ldiv_if_columns(L) = transform_ldiv_if_columns(L, axes(L.B,2))
307-
copy(L::Ldiv{<:AbstractBasisLayout,<:BroadcastLayout}) = transform_ldiv_if_columns(L)
308-
# Inclusion are QuasiArrayLayout
309-
copy(L::Ldiv{<:AbstractBasisLayout,QuasiArrayLayout}) = transform_ldiv(L.A, L.B)
310-
# Otherwise keep lazy to support, e.g., U\D*T
311-
copy(L::Ldiv{<:AbstractBasisLayout,<:AbstractLazyLayout}) = transform_ldiv_if_columns(L)
312-
copy(L::Ldiv{<:AbstractBasisLayout,ZerosLayout}) = Zeros{eltype(L)}(axes(L)...)
313299

314-
transform_ldiv_if_columns(L::Ldiv{<:Any,<:ApplyLayout{typeof(hcat)}}, ::OneTo) = transform_ldiv(L.A, L.B)
315-
transform_ldiv_if_columns(L::Ldiv{<:Any,<:ApplyLayout{typeof(hcat)}}, _) = hcat((Ref(L.A) .\ arguments(hcat, L.B))...)
300+
301+
@inline copy(L::Ldiv{<:AbstractBasisLayout}) = basis_ldiv_size(size(L), L.A, L.B)
302+
@inline copy(L::Ldiv{<:AbstractBasisLayout,<:AbstractLazyLayout}) = basis_ldiv_size(size(L), L.A, L.B)
303+
@inline function copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)}})
304+
simplifiable(\, L.A, first(arguments(*, L.B))) isa Val{true} && return copy(Ldiv{UnknownLayout,ApplyLayout{typeof(*)}}(L.A, L.B))
305+
basis_ldiv_size(size(L), L.A, L.B)
306+
end
307+
@inline copy(L::Ldiv{<:AbstractBasisLayout,ZerosLayout}) = Zeros{eltype(L)}(axes(L)...)
308+
309+
@inline basis_ldiv_size(_, A, B) = copy(Ldiv{UnknownLayout,typeof(MemoryLayout(B))}(A, B))
310+
@inline basis_ldiv_size(::Tuple{Integer}, A, B) = transform_ldiv(A, B)
311+
@inline basis_ldiv_size(::Tuple{Integer,Int}, A, B) = transform_ldiv(A, B)
312+
313+
@inline copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(hcat)}}) = basis_hcat_ldiv_size(size(L), L.A, L.B)
314+
@inline basis_hcat_ldiv_size(::Tuple{Integer,Int}, A, B) = transform_ldiv(A, B)
315+
@inline basis_hcat_ldiv_size(_, A, B) = hcat((Ref(A) .\ arguments(hcat, B))...)
316+
316317

317318
"""
318319
WeightedFactorization(w, F)
@@ -334,7 +335,14 @@ _factorize(::WeightedBasisLayouts, wS, dims...; kws...) = WeightedFactorization(
334335
##
335336

336337
struct ExpansionLayout{Lay} <: AbstractLazyLayout end
337-
applylayout(::Type{typeof(*)}, ::Lay, ::Union{PaddedLayout,AbstractStridedLayout,ZerosLayout}) where Lay <: AbstractBasisLayout = ExpansionLayout{Lay}()
338+
const CoefficientLayouts = Union{PaddedLayout,AbstractStridedLayout,ZerosLayout}
339+
applylayout(::Type{typeof(*)}, ::Lay, ::CoefficientLayouts) where Lay <: AbstractBasisLayout = ExpansionLayout{Lay}()
340+
341+
tocoefficients(v) = tocoefficients_layout(MemoryLayout(v), v)
342+
tocoefficients_layout(::CoefficientLayouts, v) = v
343+
tocoefficients_layout(_, v) = tocoefficients_size(size(v), v)
344+
tocoefficients_size(::NTuple{N,Int}, v) where N = Array(v)
345+
tocoefficients_size(_, v) = v # the default is to leave it, even though we aren't technically making an ExpansionLayout
338346

339347
"""
340348
basis(v)
@@ -359,7 +367,8 @@ function unweighted(lay::ExpansionLayout, a)
359367
end
360368

361369
LazyArrays._mul_arguments(::ExpansionLayout, A) = LazyArrays._mul_arguments(ApplyLayout{typeof(*)}(), A)
362-
copy(L::Ldiv{Bas,<:ExpansionLayout}) where Bas<:AbstractBasisLayout = copy(Ldiv{Bas,ApplyLayout{typeof(*)}}(L.A, L.B))
370+
copy(L::Ldiv{Lay,<:ExpansionLayout}) where Lay<:AbstractBasisLayout = copy(Ldiv{Lay,ApplyLayout{typeof(*)}}(L.A, L.B))
371+
copy(L::Ldiv{Lay,<:ExpansionLayout}) where Lay<:MappedBasisLayouts = copy(Ldiv{Lay,ApplyLayout{typeof(*)}}(L.A, L.B))
363372
copy(L::Mul{<:ExpansionLayout,Lay}) where Lay = copy(Mul{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
364373
copy(L::Mul{<:ExpansionLayout,Lay}) where Lay<:AbstractLazyLayout = copy(Mul{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
365374

src/operators.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ macro simplify(qt)
5555
Base.copy(M::ContinuumArrays.QMul3{<:$Atyp,<:$Btyp,<:$Ctyp}) = ContinuumArrays.simplify(M)
5656
end)
5757
end
58+
elseif qt.args[1].args[1] == :(\)
59+
mat = qt.args[2]
60+
@assert qt.args[1].args[2].head == :(::)
61+
Aname,Atyp = qt.args[1].args[2].args
62+
Bname,Btyp = qt.args[1].args[3].args
63+
esc(quote
64+
ContinuumArrays.simplifiable(::typeof(\), A::$Atyp, B::$Btyp) = Val(true)
65+
Base.@propagate_inbounds function ContinuumArrays.ldiv($Aname::$Atyp, $Bname::$Btyp)
66+
@boundscheck ContinuumArrays.check_ldiv_axes($Aname, $Bname)
67+
$mat
68+
end
69+
end)
5870
end
5971
end
6072

0 commit comments

Comments
 (0)