Skip to content

Commit 4c44bac

Browse files
authored
Add support for broadcasted expansions (#62)
* Add support for broadcasted expansions * Only do broadcasted when expansion-like * Fix tests * Update runtests.jl
1 parent e289553 commit 4c44bac

File tree

4 files changed

+59
-21
lines changed

4 files changed

+59
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.3.2"
3+
version = "0.3.3"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/ContinuumArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Base: @_inline_meta, @_propagate_inbounds_meta, axes, getindex, convert,
66
getproperty, isone, iszero, zero, abs, <, , >, , string
77
import Base.Broadcast: materialize, BroadcastStyle, broadcasted
88
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport, most, combine_mul_styles, AbstractArrayApplyStyle,
9-
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex,
9+
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex, UnknownLayout,
1010
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles, applylayout,
1111
simplifiable, _simplify
1212
import LinearAlgebra: pinv, dot, norm2

src/bases/bases.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,21 @@ for Bas1 in (:Basis, :WeightedBasis), Bas2 in (:Basis, :WeightedBasis)
9595
end
9696

9797

98+
# multiplication operators, reexpand in basis A
99+
@inline function _broadcast_mul_ldiv(::Tuple{Any,AbstractBasisLayout}, A, B)
100+
a,b = arguments(B)
101+
@assert a isa AbstractQuasiVector # Only works for vec .* mat
102+
ab = (A * (A \ a)) .* b # broadcasted should be overloaded
103+
MemoryLayout(ab) isa BroadcastLayout && error("Overload broadcasted(_, ::typeof(*), ::$(typeof(ab.args[1])), ::$(typeof(b)))")
104+
A \ ab
105+
end
106+
107+
_broadcast_mul_ldiv(_, A, B) = copy(Ldiv{typeof(MemoryLayout(A)),UnknownLayout}(A,B))
108+
109+
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
110+
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
111+
112+
98113
# expansion
99114
_grid(_, P) = error("Overload Grid")
100115
_grid(::MappedBasisLayout, P) = igetindex.(Ref(parentindices(P)[1]), grid(demap(P)))
@@ -203,6 +218,13 @@ for op in (:+, :-)
203218
end
204219
end
205220

221+
function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(*), a::Expansion, f::Expansion)
222+
axes(a,1) == axes(f,1) || throw(DimensionMismatch())
223+
P,c = arguments(f)
224+
(a .* P) * c
225+
end
226+
227+
206228
@eval function ==(f::Expansion, g::Expansion)
207229
S,c = arguments(f)
208230
T,d = arguments(g)

test/runtests.jl

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ContinuumArrays, QuasiArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, FastTransforms, InfiniteArrays, Test, Base64
22
import ContinuumArrays: ℵ₁, materialize, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, ℵ₁,
3-
MappedBasisLayout, AdjointMappedBasisLayout, MappedWeightedBasisLayout, igetindex, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
3+
MappedBasisLayout, AdjointMappedBasisLayout, MappedWeightedBasisLayout, igetindex, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
44
Expansion, basis
55
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
66
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivStyle, MulStyle
@@ -236,7 +236,7 @@ end
236236
L = LinearSpline([1,2,3])
237237
f = L*[1,2,4]
238238
g = L*[5,6,7]
239-
239+
240240
@test f isa Expansion
241241
@test 2f isa Expansion
242242
@test f*2 isa Expansion
@@ -346,7 +346,7 @@ end
346346
@testset "sub-colon" begin
347347
L = LinearSpline([1,2,3])
348348
@test L[:,1][1.1] == (L')[1,:][1.1] == L[1.1,1]
349-
@test L[:,1:2][1.1,:] == (L')[1:2,:][:,1.1] == L[1.1,1:2]
349+
@test L[:,1:2][1.1,:] == (L')[1:2,:][:,1.1] == L[1.1,1:2]
350350
end
351351

352352
@testset "transform" begin
@@ -459,7 +459,7 @@ end
459459
@testset "diff" begin
460460
L = LinearSpline(range(-1,stop=1,length=10))
461461
f = L * randn(size(L,2))
462-
h = 0.0001;
462+
h = 0.0001;
463463
@test diff(f)[0.1] (f[0.1+h]-f[0.1])/h
464464
end
465465

@@ -485,7 +485,7 @@ end
485485

486486
struct ChebyshevWeight <: Weight{Float64} end
487487

488-
488+
Base.:(==)(::Chebyshev, ::Chebyshev) = true
489489
Base.axes(T::Chebyshev) = (Inclusion(-1..1), Base.OneTo(T.n))
490490
ContinuumArrays.grid(T::Chebyshev) = chebyshevpoints(Float64, T.n, Val(1))
491491
Base.axes(T::ChebyshevWeight) = (Inclusion(-1..1),)
@@ -496,33 +496,49 @@ Base.getindex(::ChebyshevWeight, x::Float64) = 1/sqrt(1-x^2)
496496
LinearAlgebra.factorize(L::Chebyshev) =
497497
TransformFactorization(grid(L), plan_chebyshevtransform(Array{Float64}(undef, size(L,2))))
498498

499+
# This is wrong but just for tests
500+
Base.broadcasted(::LazyQuasiArrayStyle{2}, ::typeof(*), a::Expansion{<:Any,<:Chebyshev}, b::Chebyshev) = b * Matrix(I, 5, 5)
501+
499502
@testset "Chebyshev" begin
500503
T = Chebyshev(5)
504+
w = ChebyshevWeight()
505+
wT = w .* T
501506
x = axes(T,1)
502507
F = factorize(T)
503508
g = grid(F)
504509
@test T \ exp.(x) == F \ exp.(x) == F \ exp.(g) == chebyshevtransform(exp.(g), Val(1))
505510

506-
w = ChebyshevWeight()
507-
@test MemoryLayout(w) isa WeightLayout
508-
@test MemoryLayout(w[Inclusion(0..1)]) isa WeightLayout
509-
510-
wT = w .* T
511-
wT2 = w .* T[:,2:4]
512-
wT3 = wT[:,2:4]
513-
@test MemoryLayout(wT) == WeightedBasisLayout()
514-
@test MemoryLayout(wT2) == WeightedBasisLayout()
515-
@test MemoryLayout(wT3) == SubWeightedBasisLayout()
516-
@test grid(wT) == grid(wT2) == grid(wT3) == grid(T)
511+
@testset "Weighted" begin
512+
@test MemoryLayout(w) isa WeightLayout
513+
@test MemoryLayout(w[Inclusion(0..1)]) isa WeightLayout
517514

518-
@test ContinuumArrays.unweightedbasis(wT) T
519-
@test ContinuumArrays.unweightedbasis(wT2) T[:,2:4]
520-
@test ContinuumArrays.unweightedbasis(wT3) T[:,2:4]
515+
wT2 = w .* T[:,2:4]
516+
wT3 = wT[:,2:4]
517+
@test MemoryLayout(wT) == WeightedBasisLayout()
518+
@test MemoryLayout(wT2) == WeightedBasisLayout()
519+
@test MemoryLayout(wT3) == SubWeightedBasisLayout()
520+
@test grid(wT) == grid(wT2) == grid(wT3) == grid(T)
521521

522+
@test ContinuumArrays.unweightedbasis(wT) T
523+
@test ContinuumArrays.unweightedbasis(wT2) T[:,2:4]
524+
@test ContinuumArrays.unweightedbasis(wT3) T[:,2:4]
525+
end
522526
@testset "Mapped" begin
523527
y = affine(0..1, x)
524528
@test MemoryLayout(wT[y,:]) isa MappedWeightedBasisLayout
525529
@test MemoryLayout(w[y] .* T[y,:]) isa MappedWeightedBasisLayout
526530
@test wT[y,:][[0.1,0.2],1:5] == (w[y] .* T[y,:])[[0.1,0.2],1:5] == (w .* T[:,1:5])[y,:][[0.1,0.2],:]
527531
end
532+
533+
@testset "Broadcasted" begin
534+
a = 1 .+ x .+ x.^2
535+
# The following are wrong, just testing dispatch
536+
@test T \ (a .* T) == I
537+
@test T \ (a .* (T * (T \ a))) [2.875, 3.5, 2.0, 0.5, 0.125]
538+
f = exp.(x) .* a # another broadcast layout
539+
@test T \ f == F \ f
540+
541+
= T * (T \ a)
542+
@test T \ (ã .* ã) [1.5,1,0.5,0,0]
543+
end
528544
end

0 commit comments

Comments
 (0)