Skip to content

Support basis concatenation #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
name = "ContinuumArrays"
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
version = "0.4.2"
version = "0.5.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
QuasiArrays = "c4ea9172-b204-11e9-377d-29865faadc5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ArrayLayouts = "0.4.10, 0.5"
BandedMatrices = "0.15.17, 0.16"
FillArrays = "0.9.3, 0.10, 0.11"
InfiniteArrays = "0.8, 0.9"
IntervalSets = "0.4, 0.5"
LazyArrays = "0.19, 0.20"
QuasiArrays = "0.3.8"
ArrayLayouts = "0.5"
BandedMatrices = "0.16"
BlockArrays = "0.14"
FillArrays = "0.11"
InfiniteArrays = "0.9"
IntervalSets = "0.5"
LazyArrays = "0.20"
QuasiArrays = "0.4.1"
StaticArrays = "0.12, 1"
julia = "1.5"

[extras]
Expand Down
195 changes: 31 additions & 164 deletions src/ContinuumArrays.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
module ContinuumArrays
using IntervalSets, LinearAlgebra, LazyArrays, FillArrays, BandedMatrices, QuasiArrays, InfiniteArrays
using IntervalSets, LinearAlgebra, LazyArrays, FillArrays, BandedMatrices, QuasiArrays, InfiniteArrays, StaticArrays, BlockArrays
import Base: @_inline_meta, @_propagate_inbounds_meta, axes, getindex, convert, prod, *, /, \, +, -, ==, ^,
IndexStyle, IndexLinear, ==, OneTo, tail, similar, copyto!, copy, diff,
IndexStyle, IndexLinear, ==, OneTo, _maybetail, tail, similar, copyto!, copy, diff,
first, last, show, isempty, findfirst, findlast, findall, Slice, union, minimum, maximum, sum, _sum,
getproperty, isone, iszero, zero, abs, <, ≤, >, ≥, string, summary
getproperty, isone, iszero, zero, abs, <, ≤, >, ≥, string, summary, to_indices
import Base.Broadcast: materialize, BroadcastStyle, broadcasted
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport, most, combine_mul_styles, AbstractArrayApplyStyle,
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex, UnknownLayout,
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles, applylayout,
simplifiable, _simplify
import LinearAlgebra: pinv, dot, norm2
import BandedMatrices: AbstractBandedLayout, _BandedMatrix
import BlockArrays: block, blockindex, unblock, blockedrange, _BlockedUnitRange, _BlockArray
import FillArrays: AbstractFill, getindex_value, SquareEye
import ArrayLayouts: mul
import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclusion, SubQuasiArray,
QuasiDiagonal, MulQuasiArray, MulQuasiMatrix, MulQuasiVector, QuasiMatMulMat,
ApplyQuasiArray, ApplyQuasiMatrix, LazyQuasiArrayApplyStyle, AbstractQuasiArrayApplyStyle, AbstractQuasiLazyLayout,
LazyQuasiArray, LazyQuasiVector, LazyQuasiMatrix, LazyLayout, LazyQuasiArrayStyle, _factorize
import InfiniteArrays: Infinity
import InfiniteArrays: Infinity, InfAxes

export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, ℵ₁, Inclusion, Basis, WeightedBasis, grid, transform, affine
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, ℵ₁, Inclusion, Basis, WeightedBasis, grid, transform, affine, ..

####
# Interval indexing support
Expand Down Expand Up @@ -90,177 +91,43 @@ function dot(x::Inclusion{T,<:AbstractInterval}, y::Inclusion{V,<:AbstractInterv
end


function checkindex(::Type{Bool}, inds::Inclusion{<:Any,<:AbstractInterval}, r::Inclusion{<:Any,<:AbstractInterval})
@_propagate_inbounds_meta
isempty(r) | (checkindex(Bool, inds, first(r)) & checkindex(Bool, inds, last(r)))
end


###
# Maps
###

"""
A subtype of `Map` is used as a one-to-one map between two domains
via `view`. The domain of the map `m` is `axes(m,1)` and the range
is `union(m)`.

Maps must also overload `invmap` to give the inverse of the map, which
is equivalent to `invmap(m)[x] == findfirst(isequal(x), m)`.
"""

abstract type Map{T} <: AbstractQuasiVector{T} end

invmap(M::Map) = error("Overload invmap(::$(typeof(M)))")


Base.in(x, m::Map) = x in union(m)
Base.issubset(d::Map, b::IntervalSets.Domain) = union(d) ⊆ b
Base.union(d::Map) = axes(invmap(d),1)

for find in (:findfirst, :findlast)
@eval function $find(f::Base.Fix2{typeof(isequal)}, d::Map)
f.x in d || return nothing
$find(isequal(invmap(d)[f.x]), union(d))
end
end

@eval function findall(f::Base.Fix2{typeof(isequal)}, d::Map)
f.x in d || return eltype(axes(d,1))[]
findall(isequal(invmap(d)[f.x]), union(d))
end

function Base.getindex(d::Map, x::Inclusion)
x == axes(d,1) || throw(BoundsError(d, x))
d
end

# Affine map represents A*x .+ b
abstract type AbstractAffineQuasiVector{T,AA,X,B} <: Map{T} end

summary(io::IO, a::AbstractAffineQuasiVector) = print(io, "$(a.A) * $(a.x) .+ ($(a.b))")

struct AffineQuasiVector{T,AA,X,B} <: AbstractAffineQuasiVector{T,AA,X,B}
A::AA
x::X
b::B
end

AffineQuasiVector(A::AA, x::X, b::B) where {AA,X,B} =
AffineQuasiVector{promote_type(eltype(AA), eltype(X), eltype(B)),AA,X,B}(A,x,b)

AffineQuasiVector(A, x) = AffineQuasiVector(A, x, zero(promote_type(eltype(A),eltype(x))))
AffineQuasiVector(x) = AffineQuasiVector(one(eltype(x)), x)

AffineQuasiVector(A, x::AffineQuasiVector, b) = AffineQuasiVector(A*x.A, x.x, A*x.b .+ b)

axes(A::AbstractAffineQuasiVector) = axes(A.x)

affine_getindex(A, k) = A.A*A.x[k] .+ A.b
Base.unsafe_getindex(A::AbstractAffineQuasiVector, k) = A.A*Base.unsafe_getindex(A.x,k) .+ A.b
getindex(A::AbstractAffineQuasiVector, k::Number) = affine_getindex(A, k)
function getindex(A::AbstractAffineQuasiVector, k::Inclusion)
@boundscheck A.x[k] # throws bounds error if k ≠ x
A
end

getindex(A::AbstractAffineQuasiVector, ::Colon) = copy(A)

copy(A::AbstractAffineQuasiVector) = A

inbounds_getindex(A::AbstractAffineQuasiVector{<:Any,<:Any,<:Inclusion}, k::Number) = A.A*k .+ A.b
isempty(A::AbstractAffineQuasiVector) = isempty(A.x)
==(a::AbstractAffineQuasiVector, b::AbstractAffineQuasiVector) = a.A == b.A && a.x == b.x && a.b == b.b

BroadcastStyle(::Type{<:AbstractAffineQuasiVector}) = LazyQuasiArrayStyle{1}()

for op in(:*, :\, :+, :-)
@eval broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), a::Number, x::Inclusion) = broadcast($op, a, AffineQuasiVector(x))
end
for op in(:/, :+, :-)
@eval broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), x::Inclusion, a::Number) = broadcast($op, AffineQuasiVector(x), a)
end

broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(*), a::Number, x::AbstractAffineQuasiVector) = AffineQuasiVector(a, x)
broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(\), a::Number, x::AbstractAffineQuasiVector) = AffineQuasiVector(inv(a), x)
broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(/), x::AbstractAffineQuasiVector, a::Number) = AffineQuasiVector(inv(a), x)
broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(+), a::Number, x::AbstractAffineQuasiVector) = AffineQuasiVector(one(eltype(x)), x, a)
broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(+), x::AbstractAffineQuasiVector, b::Number) = AffineQuasiVector(one(eltype(x)), x, b)
broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(-), a::Number, x::AbstractAffineQuasiVector) = AffineQuasiVector(-one(eltype(x)), x, a)
broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(-), x::AbstractAffineQuasiVector, b::Number) = AffineQuasiVector(one(eltype(x)), x, -b)

function checkindex(::Type{Bool}, inds::Inclusion{<:Any,<:AbstractInterval}, r::AbstractAffineQuasiVector)
@_propagate_inbounds_meta
isempty(r) | (checkindex(Bool, inds, first(r)) & checkindex(Bool, inds, last(r)))
end

minimum(d::AbstractAffineQuasiVector) = signbit(d.A) ? last(d) : first(d)
maximum(d::AbstractAffineQuasiVector) = signbit(d.A) ? first(d) : last(d)

union(d::AbstractAffineQuasiVector) = Inclusion(minimum(d)..maximum(d))
invmap(d::AbstractAffineQuasiVector) = affine(union(d), axes(d,1))



include("maps.jl")

const QInfAxes = Union{Inclusion,AbstractAffineQuasiVector}

struct AffineMap{T,D,R} <: AbstractAffineQuasiVector{T,T,D,T}
domain::D
range::R
end

AffineMap(domain::AbstractQuasiVector{T}, range::AbstractQuasiVector{V}) where {T,V} =
AffineMap{promote_type(T,V), typeof(domain),typeof(range)}(domain,range)
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes,QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{Any,QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes,Any}) = V

measure(x::Inclusion) = last(x)-first(x)
# ambiguity error
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{InfAxes,QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes,InfAxes}) = V

function getproperty(A::AffineMap, d::Symbol)
domain, range = getfield(A, :domain), getfield(A, :range)
d == :x && return domain
d == :A && return measure(range)/measure(domain)
d == :b && return (last(domain)*first(range) - first(domain)*last(range))/measure(domain)
getfield(A, d)
end
#
# BlockQuasiArrays

function getindex(A::AffineMap, k::Number)
# ensure we exactly hit range
k == first(A.domain) && return first(A.range)
k == last(A.domain) && return last(A.range)
affine_getindex(A, k)
BlockArrays.blockaxes(::Inclusion) = blockaxes(Base.OneTo(1)) # just use 1 block
function BlockArrays.blockaxes(A::AbstractQuasiArray{T,N}, d) where {T,N}
@_inline_meta
d::Integer <= N ? blockaxes(A)[d] : Base.OneTo(1)
end

@inline to_indices(A::AbstractQuasiArray, inds, I::Tuple{Block{1}, Vararg{Any}}) =
(unblock(A, inds, I), to_indices(A, _maybetail(inds), tail(I))...)
@inline to_indices(A::AbstractQuasiArray, inds, I::Tuple{BlockRange{1,R}, Vararg{Any}}) where R =
(unblock(A, inds, I), to_indices(A, _maybetail(inds), tail(I))...)
@inline to_indices(A::AbstractQuasiArray, inds, I::Tuple{BlockIndex{1}, Vararg{Any}}) =
(inds[1][I[1]], to_indices(A, _maybetail(inds), tail(I))...)

first(A::AffineMap) = first(A.range)
last(A::AffineMap) = last(A.range)

affine(a::AbstractQuasiVector, b::AbstractQuasiVector) = AffineMap(a, b)
affine(a, b::AbstractQuasiVector) = affine(Inclusion(a), b)
affine(a::AbstractQuasiVector, b) = affine(a, Inclusion(b))
affine(a, b) = affine(Inclusion(a), Inclusion(b))


# mapped vectors
const AffineMappedQuasiVector = SubQuasiArray{<:Any, 1, <:Any, <:Tuple{AbstractAffineQuasiVector}}
const AffineMappedQuasiMatrix = SubQuasiArray{<:Any, 2, <:Any, <:Tuple{AbstractAffineQuasiVector,Slice}}

==(a::AffineMappedQuasiVector, b::AffineMappedQuasiVector) = parentindices(a) == parentindices(b) && parent(a) == parent(b)

_sum(V::AffineMappedQuasiVector, ::Colon) = parentindices(V)[1].A \ sum(parent(V))

# pretty print for bases
summary(io::IO, P::AffineMappedQuasiMatrix) = print(io, "$(parent(P)) affine mapped to $(parentindices(P)[1].x.domain)")
summary(io::IO, P::AffineMappedQuasiVector) = print(io, "$(parent(P)) affine mapped to $(parentindices(P)[1].x.domain)")

const QInfAxes = Union{Inclusion,AbstractAffineQuasiVector}


sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes,QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{Any,QInfAxes}) = V
sub_materialize(_, V::AbstractQuasiArray, ::Tuple{QInfAxes,Any}) = V
checkpoints(d::AbstractInterval{T}) where T = width(d) .* SVector{3,float(T)}(0.823972,0.01,0.3273484) .+ leftendpoint(d)
checkpoints(x::Inclusion) = checkpoints(x.domain)
checkpoints(A::AbstractQuasiMatrix) = checkpoints(axes(A,1))


include("operators.jl")
include("bases/bases.jl")
include("basisconcat.jl")

end
6 changes: 2 additions & 4 deletions src/bases/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ end
@inline function copy(P::Ldiv{<:AbstractBasisLayout,<:AbstractBasisLayout})
A, B = P.A, P.B
A == B || throw(ArgumentError("Override copy for $(typeof(A)) \\ $(typeof(B))"))
SquareEye{eltype(P)}((axes(A,2),))
SquareEye{eltype(eltype(P))}((axes(A,2),)) # use double eltype for array-valued
end
@inline function copy(P::Ldiv{<:SubBasisLayouts,<:SubBasisLayouts})
A, B = P.A, P.B
parent(A) == parent(B) ||
throw(ArgumentError("Override copy for $(typeof(A)) \\ $(typeof(B))"))
Eye{eltype(P)}((axes(A,2),axes(B,2)))
Eye{eltype(eltype(P))}((axes(A,2),axes(B,2)))
end

@inline function copy(P::Ldiv{<:MappedBasisLayouts,<:MappedBasisLayouts})
Expand Down Expand Up @@ -119,8 +119,6 @@ _grid(::SubBasisLayout, P) = grid(parent(P))
_grid(::WeightedBasisLayouts, P) = grid(unweightedbasis(P))
grid(P) = _grid(MemoryLayout(typeof(P)), P)

# TODO: Move over from OrthogonalPolynomialsQuasi
function checkpoints end

struct TransformFactorization{T,Grid,Plan,IPlan} <: Factorization{T}
grid::Grid
Expand Down
Loading