Skip to content

Derivative calls diff #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 14, 2023
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ContinuumArrays"
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
version = "0.13"
version = "0.14"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down Expand Up @@ -30,7 +30,7 @@ InfiniteArrays = "0.12"
Infinities = "0.1"
IntervalSets = "0.7"
LazyArrays = "1.0"
QuasiArrays = "0.10"
QuasiArrays = "0.11"
RecipesBase = "1.0"
StaticArrays = "1.0"
julia = "1.9"
Expand Down
6 changes: 4 additions & 2 deletions src/ContinuumArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclu
QuasiDiagonal, MulQuasiArray, MulQuasiMatrix, MulQuasiVector, QuasiMatMulMat, QuasiArrayLayout,
ApplyQuasiArray, ApplyQuasiMatrix, LazyQuasiArrayApplyStyle, AbstractQuasiArrayApplyStyle, AbstractQuasiLazyLayout,
LazyQuasiArray, LazyQuasiVector, LazyQuasiMatrix, LazyLayout, LazyQuasiArrayStyle, _factorize, _cutdim,
AbstractQuasiFill, UnionDomain, sum_size, sum_layout, _cumsum, cumsum_layout, applylayout, _equals, layout_broadcasted, PolynomialLayout, _dot
AbstractQuasiFill, UnionDomain, sum_size, sum_layout, _cumsum, cumsum_layout, applylayout, _equals, layout_broadcasted, PolynomialLayout, dot_size,
diff_layout, diff_size
import InfiniteArrays: Infinity, InfAxes
import AbstractFFTs: Plan

Expand Down Expand Up @@ -104,7 +105,8 @@ include("plotting.jl")
###

sum_size(::Tuple{InfiniteCardinal{1}}, a, dims) = _sum(expand(a), dims)
_dot(::InfiniteCardinal{1}, a, b) = dot(expand(a), expand(b))
dot_size(::InfiniteCardinal{1}, a, b) = dot(expand(a), expand(b))
diff_size(::Tuple{InfiniteCardinal{1}}, a, dims) = diff(expand(a); dims=dims)
function copy(d::Dot{<:ExpansionLayout,<:ExpansionLayout,<:AbstractQuasiArray,<:AbstractQuasiArray})
a,b = d.A,d.B
P,c = basis(a),coefficients(a)
Expand Down
119 changes: 72 additions & 47 deletions src/bases/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ abstract type Basis{T} <: LazyQuasiMatrix{T} end
abstract type Weight{T} <: LazyQuasiVector{T} end


struct WeightLayout <: AbstractQuasiLazyLayout end
abstract type AbstractWeightLayout <: AbstractQuasiLazyLayout end
struct WeightLayout <: AbstractWeightLayout end
struct MappedWeightLayout <: AbstractWeightLayout end
abstract type AbstractBasisLayout <: AbstractQuasiLazyLayout end
abstract type AbstractWeightedBasisLayout <: AbstractBasisLayout end
struct BasisLayout <: AbstractBasisLayout end
Expand All @@ -12,22 +14,22 @@ struct WeightedBasisLayout{Basis} <: AbstractWeightedBasisLayout end
const SubWeightedBasisLayout = WeightedBasisLayout{SubBasisLayout}
const MappedWeightedBasisLayout = WeightedBasisLayout{MappedBasisLayout}

struct AdjointBasisLayout{Basis} <: AbstractQuasiLazyLayout end
const AdjointSubBasisLayout = AdjointBasisLayout{SubBasisLayout}

SubBasisLayouts = Union{SubBasisLayout,SubWeightedBasisLayout}
WeightedBasisLayouts = Union{WeightedBasisLayout,SubWeightedBasisLayout,MappedWeightedBasisLayout}
MappedBasisLayouts = Union{MappedBasisLayout,MappedWeightedBasisLayout}

struct AdjointBasisLayout{Basis} <: AbstractQuasiLazyLayout end
const AdjointSubBasisLayout = AdjointBasisLayout{SubBasisLayout}
const AdjointMappedBasisLayout = AdjointBasisLayout{MappedBasisLayout}
AdjointMappedBasisLayouts = AdjointBasisLayout{<:MappedBasisLayouts}

MemoryLayout(::Type{<:Basis}) = BasisLayout()
MemoryLayout(::Type{<:Weight}) = WeightLayout()

adjointlayout(::Type, ::Basis) where Basis<:AbstractBasisLayout = AdjointBasisLayout{Basis}()
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::Basis) where Basis<:AbstractBasisLayout = WeightedBasisLayout{Basis}()
broadcastlayout(::Type{typeof(*)}, ::AbstractWeightLayout, ::Basis) where Basis<:AbstractBasisLayout = WeightedBasisLayout{Basis}()

# A sub of a weight is still a weight
sublayout(::WeightLayout, _) = WeightLayout()
sublayout(::AbstractWeightLayout, _) = WeightLayout()
sublayout(::AbstractWeightLayout, ::Type{<:Tuple{Map}}) = MappedWeightLayout()
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{Map,AbstractVector}}) = MappedBasisLayout()

# copy with an Inclusion can not be materialized
Expand All @@ -39,6 +41,7 @@ unweighted(P::BroadcastQuasiMatrix{<:Any,typeof(*),<:Tuple{AbstractQuasiVector,A
unweighted(V::SubQuasiArray) = view(unweighted(parent(V)), parentindices(V)...)
weight(P::BroadcastQuasiMatrix{<:Any,typeof(*),<:Tuple{AbstractQuasiVector,AbstractQuasiMatrix}}) = first(P.args)
weight(V::SubQuasiArray) = weight(parent(V))[parentindices(V)[1]]
weight(V::SubQuasiArray{<:Any,2,<:Any, <:Tuple{Inclusion,Any}}) = weight(parent(V))

unweighted(a::AbstractQuasiArray) = unweighted(MemoryLayout(a), a)
# Default is lazy
Expand Down Expand Up @@ -109,7 +112,7 @@ copy(L::Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVe
end

# default to transform for expanding weights
copy(L::Ldiv{<:AbstractBasisLayout,WeightLayout}) = transform_ldiv(L.A, L.B)
copy(L::Ldiv{<:AbstractBasisLayout,<:AbstractWeightLayout}) = transform_ldiv(L.A, L.B)

# multiplication operators, reexpand in basis A
@inline function _broadcast_mul_ldiv(::Tuple{Any,AbstractBasisLayout}, A, B)
Expand Down Expand Up @@ -341,6 +344,8 @@ gives a basis for expanding given quasi-vector.
basis(v) = basis_layout(MemoryLayout(v), v)

basis_layout(::ExpansionLayout, v::ApplyQuasiArray{<:Any,N,typeof(*)}) where N = v.args[1]
basis_layout(lay::ApplyLayout{typeof(*)}, v) = basis(first(arguments(lay, v)))
basis_layout(lay::AbstractBasisLayout, v) = v
basis_layout(lay, v) = basis_axes(axes(v,1), v) # allow choosing a basis based on axes
basis_axes(ax, v) = error("Overload for $ax")

Expand Down Expand Up @@ -499,43 +504,6 @@ end
# \int_a^b f(y) g(y) dy = \int_{-1}^1 f(p(x))*g(p(x)) * p'(x) dx


_sub_getindex(A, kr, jr) = A[kr, jr]
_sub_getindex(A, ::Slice, ::Slice) = A

@simplify function *(Ac::QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}},
B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}})
A = Ac'
PA,PB = parent(A),parent(B)
kr,jr = parentindices(B)
_sub_getindex((PA'PB)/kr.A,parentindices(A)[2],jr)
end


# Differentiation of sub-arrays

# avoid stack overflow from unmaterialize Derivative() * parent()
_der_sub(DP, inds...) = DP[inds...]
_der_sub(DP::ApplyQuasiMatrix{T,typeof(*),<:Tuple{Derivative,Any}}, kr, jr) where T = ApplyQuasiMatrix{T}(*, DP.args[1], view(DP.args[2], kr, jr))

# need to customise simplifiable so can't use @simplify
simplifiable(::typeof(*), A::Derivative, B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}})= simplifiable(*, Derivative(axes(parent(B),1)), parent(B))
simplifiable(::typeof(*), Ac::QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}}}, Bc::QuasiAdjoint{<:Any,<:Derivative}) = simplifiable(*, Bc', Ac')
function mul(A::Derivative, B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}})
axes(A,2) == axes(B,1) || throw(DimensionMismatch())
P = parent(B)
_der_sub(Derivative(axes(P,1))*P, parentindices(B)...)
end
mul(Ac::QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}}}, Bc::QuasiAdjoint{<:Any,<:Derivative}) = mul(Bc', Ac')'

simplifiable(::typeof(*), A::Derivative, B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}) = simplifiable(*, Derivative(axes(parent(B),1)), parent(B))
simplifiable(::typeof(*), Ac::QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}}, Bc::QuasiAdjoint{<:Any,<:Derivative}) = simplifiable(*, Bc', Ac')
function mul(A::Derivative, B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}})
axes(A,2) == axes(B,1) || throw(DimensionMismatch())
P = parent(B)
kr,jr = parentindices(B)
(Derivative(axes(P,1))*P*kr.A)[kr,jr]
end
mul(Ac::QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}}, Bc::QuasiAdjoint{<:Any,<:Derivative}) = mul(Bc', Ac')'

# we represent as a Mul with a banded matrix
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractVector}}) = SubBasisLayout()
Expand Down Expand Up @@ -600,7 +568,7 @@ end
# sum
####
function sum_layout(::SubBasisLayout, Vm, dims)
@assert dims == 1
dims == 1 || error("not implemented")
sum(parent(Vm); dims=dims)[:,parentindices(Vm)[2]]
end

Expand All @@ -616,6 +584,63 @@ end
sum_layout(::ExpansionLayout, A, dims) = sum_layout(ApplyLayout{typeof(*)}(), A, dims)
cumsum_layout(::ExpansionLayout, A, dims) = cumsum_layout(ApplyLayout{typeof(*)}(), A, dims)

###
# diff
###

function diff_layout(::SubBasisLayout, Vm, dims::Integer)
dims == 1 || error("not implemented")
diff(parent(Vm); dims=dims)[:,parentindices(Vm)[2]]
end

function diff_layout(::WeightedBasisLayout{<:SubBasisLayout}, Vm, dims::Integer)
dims == 1 || error("not implemented")
w = weight(Vm)
V = unweighted(Vm)
view(diff(w .* parent(V)), parentindices(V)...)
end

function diff_layout(::MappedBasisLayouts, V, dims)
kr = basismap(V)
@assert kr isa AbstractAffineQuasiVector
D = diff(demap(V); dims=dims)
view(basis(D), kr, :) * (kr.A*coefficients(D))
end

diff_layout(::ExpansionLayout, A, dims...) = diff_layout(ApplyLayout{typeof(*)}(), A, dims...)


####
# Gram matrix
####

simplifiable(::Mul{<:AdjointBasisLayout, <:AbstractBasisLayout}) = Val(true)
function copy(M::Mul{<:AdjointBasisLayout, <:AbstractBasisLayout})
A = (M.A)'
A == M.B && return grammatrix(A)
error("Not implemented")
end

grammatrix(A) = grammatrix_layout(MemoryLayout(A), A)
grammatrix_layout(_, A) = error("Not implemented")

function grammatrix_layout(::MappedBasisLayouts, P)
Q = demap(P)
kr = basismap(P)
@assert kr isa AbstractAffineQuasiVector
grammatrix(Q)/kr.A
end

function copy(M::Mul{<:AdjointMappedBasisLayouts, <:MappedBasisLayouts})
A = M.A'
kr = basismap(A)
@assert kr isa AbstractAffineQuasiVector
@assert kr == basismap(M.B)
demap(A)'demap(M.B) / kr.A
end



include("basisconcat.jl")
include("basiskron.jl")
include("splines.jl")
8 changes: 4 additions & 4 deletions src/bases/basisconcat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ abstract type AbstractConcatBasis{T} <: Basis{T} end

copy(S::AbstractConcatBasis) = S

@simplify function *(D::Derivative, S::AbstractConcatBasis)
axes(D,2) == axes(S,1) || throw(DimensionMismatch())
args = arguments.(Ref(ApplyLayout{typeof(*)}()), Derivative.(axes.(S.args,1)) .* S.args)
function diff(S::AbstractConcatBasis; dims::Integer)
dims == 1 || error("not implemented")
args = arguments.(Ref(ApplyLayout{typeof(*)}()), diff.(S.args; dims=dims))
all(length.(args) .== 2) || error("Not implemented")
concatbasis(S, map(first, args)...) * mortar(Diagonal([map(last, args)...]))
end
Expand Down Expand Up @@ -112,4 +112,4 @@ function QuasiArrays._getindex(::Type{IND}, A::HvcatBasis{T}, (x,j)::IND) where
end


@simplify *(D::Derivative, H::ApplyQuasiMatrix{<:Any,typeof(hcat)}) = hcat((Ref(D) .* H.args)...)
diff(H::ApplyQuasiMatrix{<:Any,typeof(hcat)}; dims::Integer) = hcat((diff.(H.args; dims=dims))...)
65 changes: 14 additions & 51 deletions src/bases/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,10 @@ grid(L::LinearSpline, n...) = L.points
## Sub-bases


## Mass matrix
function similar(AB::QMul2{<:QuasiAdjoint{<:Any,<:LinearSpline},<:LinearSpline}, ::Type{T}) where T
n = size(AB,1)
SymTridiagonal(Vector{T}(undef, n), Vector{T}(undef, n-1))
end
#
@simplify function *(Ac::QuasiAdjoint{<:Any,<:LinearSpline}, B::LinearSpline)
M = Mul(Ac, B)
copyto!(similar(M, eltype(M)), M)
end

function copyto!(dest::SymTridiagonal,
AB::QMul2{<:QuasiAdjoint{<:Any,<:LinearSpline},<:LinearSpline})
Ac,B = AB.A,AB.B
A = parent(Ac)
A.points == B.points || throw(ArgumentError())
dv,ev = dest.dv,dest.ev
## Gram matrix
function grammatrix(A::LinearSpline{T}) where T
x = A.points; n = length(x)
length(dv) == n || throw(DimensionMismatch())
dv,ev = Vector{T}(undef, n), Vector{T}(undef, n-1)

dv[1] = (x[2]-x[1])/3
@inbounds for k = 2:n-1
Expand All @@ -83,45 +68,23 @@ function copyto!(dest::SymTridiagonal,
ev[k] = (x[k+1]-x[k])/6
end

dest
SymTridiagonal(dv, ev)
end


@simplify function *(Ac::QuasiAdjoint{<:Any,<:HeavisideSpline}, B::HeavisideSpline)
A = parent(Ac)
A.points == B.points || throw(ArgumentError("Cannot multiply incompatible splines"))
Diagonal(diff(A.points))
end
grammatrix(A::HeavisideSpline) = Diagonal(diff(A.points))


## Differentiation
function copyto!(dest::MulQuasiMatrix{<:Any,<:Tuple{<:HeavisideSpline,<:Any}},
M::QMul2{<:Derivative,<:LinearSpline})
D, L = M.A, M.B
H, A = dest.args
x = H.points

axes(dest) == axes(M) || throw(DimensionMismatch("axes must be same"))
x == L.points || throw(ArgumentError("Cannot multiply incompatible splines"))
bandwidths(A) == (0,1) || throw(ArgumentError("Not implemented"))

function diff(L::LinearSpline{T}; dims::Integer=1) where T
dims == 1 || error("not implemented")
n = size(L,2)
x = L.points
D = BandedMatrix{T}(undef, (n-1,n), (0,1))
d = diff(x)
A[band(0)] .= inv.((-).(d))
A[band(1)] .= inv.(d)

dest
end

function similar(M::QMul2{<:Derivative,<:LinearSpline}, ::Type{T}) where T
D, B = M.A, M.B
n = size(B,2)
ApplyQuasiMatrix(*, HeavisideSpline{T}(B.points),
BandedMatrix{T}(undef, (n-1,n), (0,1)))
end

@simplify function *(D::Derivative, L::LinearSpline)
M = Mul(D, L)
copyto!(similar(M, eltype(M)), M)
D[band(0)] .= inv.((-).(d))
D[band(1)] .= inv.(d)
ApplyQuasiMatrix(*, HeavisideSpline{T}(x), D)
end


Expand All @@ -130,7 +93,7 @@ end
##

function _sum(A::HeavisideSpline, dims)
@assert dims == 1
dims == 1 || error("not implemented")
permutedims(diff(A.points))
end

Expand Down
22 changes: 9 additions & 13 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,19 @@ axes(D::Derivative) = (D.axis, D.axis)
==(a::Derivative, b::Derivative) = a.axis == b.axis
copy(D::Derivative) = Derivative(copy(D.axis))

function diff(d::AbstractQuasiVector)
x = axes(d,1)
Derivative(x)*d
@simplify function *(D::Derivative, B::AbstractQuasiMatrix)
T = typeof(zero(eltype(D)) * zero(eltype(B)))
diff(convert(AbstractQuasiMatrix{T}, B); dims=1)
end

function diff(A::AbstractQuasiArray; dims::Integer)
if dims == 1
Derivative(axes(A,1)) * A
else
error("Not implemented")
end
@simplify function *(D::Derivative, B::AbstractQuasiVector)
T = typeof(zero(eltype(D)) * zero(eltype(B)))
diff(convert(AbstractQuasiVector{T}, B))
end




^(D::Derivative, k::Integer) = ApplyQuasiArray(^, D, k)


Expand All @@ -153,10 +153,6 @@ const Identity{T,D} = QuasiDiagonal{T,Inclusion{T,D}}

Identity(d::Inclusion) = QuasiDiagonal(d)

@simplify *(D::Derivative, x::Inclusion) = ones(promote_type(eltype(D),eltype(x)), x)
@simplify *(D::Derivative, c::AbstractQuasiFill) = zeros(promote_type(eltype(D),eltype(c)), axes(c,1))
# @simplify *(D::Derivative, x::AbstractQuasiVector) = D * expand(x)

struct OperatorLayout <: AbstractLazyLayout end
MemoryLayout(::Type{<:Derivative}) = OperatorLayout()
# copy(M::Mul{OperatorLayout, <:ExpansionLayout}) = simplify(M)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ContinuumArrays, QuasiArrays, IntervalSets, DomainSets, FillArrays, LinearAlgebra, BandedMatrices, InfiniteArrays, Test, Base64, RecipesBase
import ContinuumArrays: ℵ₁, materialize, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, ℵ₁,
MappedBasisLayout, AdjointMappedBasisLayout, MappedWeightedBasisLayout, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
MappedBasisLayout, AdjointMappedBasisLayouts, MappedWeightedBasisLayout, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
basis, invmap, Map, checkpoints, _plotgrid, mul, plotvalues
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivStyle, MulStyle
Expand Down
Loading