Skip to content

Add support for broadcasted expansions #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ContinuumArrays"
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
version = "0.3.2"
version = "0.3.3"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
2 changes: 1 addition & 1 deletion src/ContinuumArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: @_inline_meta, @_propagate_inbounds_meta, axes, getindex, convert,
getproperty, isone, iszero, zero, abs, <, ≤, >, ≥, string
import Base.Broadcast: materialize, BroadcastStyle, broadcasted
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport, most, combine_mul_styles, AbstractArrayApplyStyle,
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex,
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex, UnknownLayout,
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles, applylayout,
simplifiable, _simplify
import LinearAlgebra: pinv, dot, norm2
Expand Down
22 changes: 22 additions & 0 deletions src/bases/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,21 @@ for Bas1 in (:Basis, :WeightedBasis), Bas2 in (:Basis, :WeightedBasis)
end


# multiplication operators, reexpand in basis A
@inline function _broadcast_mul_ldiv(::Tuple{Any,AbstractBasisLayout}, A, B)
a,b = arguments(B)
@assert a isa AbstractQuasiVector # Only works for vec .* mat
ab = (A * (A \ a)) .* b # broadcasted should be overloaded
MemoryLayout(ab) isa BroadcastLayout && error("Overload broadcasted(_, ::typeof(*), ::$(typeof(ab.args[1])), ::$(typeof(b)))")
A \ ab
end

_broadcast_mul_ldiv(_, A, B) = copy(Ldiv{typeof(MemoryLayout(A)),UnknownLayout}(A,B))

copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)
copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector}) = _broadcast_mul_ldiv(map(MemoryLayout,arguments(L.B)), L.A, L.B)


# expansion
_grid(_, P) = error("Overload Grid")
_grid(::MappedBasisLayout, P) = igetindex.(Ref(parentindices(P)[1]), grid(demap(P)))
Expand Down Expand Up @@ -203,6 +218,13 @@ for op in (:+, :-)
end
end

function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(*), a::Expansion, f::Expansion)
axes(a,1) == axes(f,1) || throw(DimensionMismatch())
P,c = arguments(f)
(a .* P) * c
end


@eval function ==(f::Expansion, g::Expansion)
S,c = arguments(f)
T,d = arguments(g)
Expand Down
54 changes: 35 additions & 19 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ContinuumArrays, QuasiArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, FastTransforms, InfiniteArrays, Test, Base64
import ContinuumArrays: ℵ₁, materialize, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, ℵ₁,
MappedBasisLayout, AdjointMappedBasisLayout, MappedWeightedBasisLayout, igetindex, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
MappedBasisLayout, AdjointMappedBasisLayout, MappedWeightedBasisLayout, igetindex, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
Expansion, basis
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivStyle, MulStyle
Expand Down Expand Up @@ -236,7 +236,7 @@ end
L = LinearSpline([1,2,3])
f = L*[1,2,4]
g = L*[5,6,7]

@test f isa Expansion
@test 2f isa Expansion
@test f*2 isa Expansion
Expand Down Expand Up @@ -346,7 +346,7 @@ end
@testset "sub-colon" begin
L = LinearSpline([1,2,3])
@test L[:,1][1.1] == (L')[1,:][1.1] == L[1.1,1]
@test L[:,1:2][1.1,:] == (L')[1:2,:][:,1.1] == L[1.1,1:2]
@test L[:,1:2][1.1,:] == (L')[1:2,:][:,1.1] == L[1.1,1:2]
end

@testset "transform" begin
Expand Down Expand Up @@ -459,7 +459,7 @@ end
@testset "diff" begin
L = LinearSpline(range(-1,stop=1,length=10))
f = L * randn(size(L,2))
h = 0.0001;
h = 0.0001;
@test diff(f)[0.1] ≈ (f[0.1+h]-f[0.1])/h
end

Expand All @@ -485,7 +485,7 @@ end

struct ChebyshevWeight <: Weight{Float64} end


Base.:(==)(::Chebyshev, ::Chebyshev) = true
Base.axes(T::Chebyshev) = (Inclusion(-1..1), Base.OneTo(T.n))
ContinuumArrays.grid(T::Chebyshev) = chebyshevpoints(Float64, T.n, Val(1))
Base.axes(T::ChebyshevWeight) = (Inclusion(-1..1),)
Expand All @@ -496,33 +496,49 @@ Base.getindex(::ChebyshevWeight, x::Float64) = 1/sqrt(1-x^2)
LinearAlgebra.factorize(L::Chebyshev) =
TransformFactorization(grid(L), plan_chebyshevtransform(Array{Float64}(undef, size(L,2))))

# This is wrong but just for tests
Base.broadcasted(::LazyQuasiArrayStyle{2}, ::typeof(*), a::Expansion{<:Any,<:Chebyshev}, b::Chebyshev) = b * Matrix(I, 5, 5)

@testset "Chebyshev" begin
T = Chebyshev(5)
w = ChebyshevWeight()
wT = w .* T
x = axes(T,1)
F = factorize(T)
g = grid(F)
@test T \ exp.(x) == F \ exp.(x) == F \ exp.(g) == chebyshevtransform(exp.(g), Val(1))

w = ChebyshevWeight()
@test MemoryLayout(w) isa WeightLayout
@test MemoryLayout(w[Inclusion(0..1)]) isa WeightLayout

wT = w .* T
wT2 = w .* T[:,2:4]
wT3 = wT[:,2:4]
@test MemoryLayout(wT) == WeightedBasisLayout()
@test MemoryLayout(wT2) == WeightedBasisLayout()
@test MemoryLayout(wT3) == SubWeightedBasisLayout()
@test grid(wT) == grid(wT2) == grid(wT3) == grid(T)
@testset "Weighted" begin
@test MemoryLayout(w) isa WeightLayout
@test MemoryLayout(w[Inclusion(0..1)]) isa WeightLayout

@test ContinuumArrays.unweightedbasis(wT) ≡ T
@test ContinuumArrays.unweightedbasis(wT2) ≡ T[:,2:4]
@test ContinuumArrays.unweightedbasis(wT3) ≡ T[:,2:4]
wT2 = w .* T[:,2:4]
wT3 = wT[:,2:4]
@test MemoryLayout(wT) == WeightedBasisLayout()
@test MemoryLayout(wT2) == WeightedBasisLayout()
@test MemoryLayout(wT3) == SubWeightedBasisLayout()
@test grid(wT) == grid(wT2) == grid(wT3) == grid(T)

@test ContinuumArrays.unweightedbasis(wT) ≡ T
@test ContinuumArrays.unweightedbasis(wT2) ≡ T[:,2:4]
@test ContinuumArrays.unweightedbasis(wT3) ≡ T[:,2:4]
end
@testset "Mapped" begin
y = affine(0..1, x)
@test MemoryLayout(wT[y,:]) isa MappedWeightedBasisLayout
@test MemoryLayout(w[y] .* T[y,:]) isa MappedWeightedBasisLayout
@test wT[y,:][[0.1,0.2],1:5] == (w[y] .* T[y,:])[[0.1,0.2],1:5] == (w .* T[:,1:5])[y,:][[0.1,0.2],:]
end

@testset "Broadcasted" begin
a = 1 .+ x .+ x.^2
# The following are wrong, just testing dispatch
@test T \ (a .* T) == I
@test T \ (a .* (T * (T \ a))) ≈ [2.875, 3.5, 2.0, 0.5, 0.125]
f = exp.(x) .* a # another broadcast layout
@test T \ f == F \ f

ã = T * (T \ a)
@test T \ (ã .* ã) ≈ [1.5,1,0.5,0,0]
end
end