Skip to content

Commit 7cd4ff6

Browse files
committed
views almost working
1 parent 0fa85ab commit 7cd4ff6

File tree

11 files changed

+75
-28
lines changed

11 files changed

+75
-28
lines changed

src/ContinuumArrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
module ContinuumArrays
22
using IntervalSets, LinearAlgebra, LazyArrays, BandedMatrices
33
import Base: @_inline_meta, axes, getindex, convert, prod, *, /, \, +, -,
4-
IndexStyle, IndexLinear
4+
IndexStyle, IndexLinear, ==
55
import Base.Broadcast: materialize
66
import LazyArrays: Mul2
77
import BandedMatrices: AbstractBandedLayout
88

99
include("QuasiArrays/QuasiArrays.jl")
1010
using .QuasiArrays
11-
import .QuasiArrays: _length, checkindex, Adjoint, Transpose
11+
import .QuasiArrays: _length, checkindex, Adjoint, Transpose, slice, QSlice
1212

1313
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative
1414

@@ -25,6 +25,7 @@ _length(::AbstractInterval) = ℵ₁
2525

2626

2727
checkindex(::Type{Bool}, inds::AbstractInterval, i::Real) = (leftendpoint(inds) <= i) & (i <= rightendpoint(inds))
28+
checkindex(::Type{Bool}, inds::AbstractInterval, i::QSlice) = i.axis inds
2829
function checkindex(::Type{Bool}, inds::AbstractInterval, I::AbstractArray)
2930
@_inline_meta
3031
b = true

src/QuasiArrays/QuasiArrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ import Base: getindex, size, axes, length, ==, isequal, iterate, CartesianIndice
66
import Base: @_inline_meta, DimOrInd, OneTo, @_propagate_inbounds_meta, @_noinline_meta,
77
DimsInteger, error_if_canonical_getindex, @propagate_inbounds, _return_type, _default_type,
88
_maybetail, tail, _getindex, _maybe_reshape, index_ndims, _unsafe_getindex,
9-
index_shape, to_shape, unsafe_length, @nloops, @ncall, Slice
10-
import Base: ViewIndex, Slice, ScalarIndex, RangeIndex
9+
index_shape, to_shape, unsafe_length, @nloops, @ncall, Slice, unalias
10+
import Base: ViewIndex, Slice, ScalarIndex, RangeIndex, view, viewindexing, ensure_indexable, index_dimsum,
11+
check_parent_index_match, reindex, _isdisjoint
1112
import Base: *, /, \, +, -, inv
1213

1314
import Base.Broadcast: materialize

src/QuasiArrays/abstractquasiarray.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,8 +560,9 @@ error_if_canonical_getindex(::IndexCartesian, A::AbstractQuasiArray{T,N}, ::Vara
560560
error_if_canonical_getindex(::IndexStyle, ::AbstractQuasiArray, ::Any...) = nothing
561561

562562
## Internal definitions
563-
_getindex(::IndexStyle, A::AbstractQuasiArray, I...) =
564-
error("getindex for $(typeof(A)) with types $(typeof(I)) is not supported")
563+
function _getindex(::IndexStyle, A::AbstractQuasiArray, I...)
564+
materialize(view(A, I...))
565+
end
565566

566567
## IndexLinear Scalar indexing: canonical method is one Int
567568
_getindex(::IndexLinear, A::AbstractQuasiArray, i::Real) = (@_propagate_inbounds_meta; getindex(A, i))

src/QuasiArrays/adjtrans.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,5 @@ end
218218
function adjoint(M::Mul)
219219
Mul(reverse(adjoint.(M.factors))...)
220220
end
221+
222+
==(A::Adjoint, B::Adjoint) = parent(A) == parent(B)

src/QuasiArrays/indices.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,32 @@ to_index(I::AbstractQuasiArray{<:Union{AbstractArray, Colon}}) =
144144
throw(ArgumentError("invalid index: $I of type $(typeof(I))"))
145145

146146
LinearIndices(A::AbstractQuasiArray) = LinearIndices(axes(A))
147+
148+
149+
150+
"""
151+
QSlice(indices)
152+
153+
Represent an axis as a quasi-vector that returns itself.
154+
"""
155+
struct QSlice{T,AX} <: AbstractQuasiVector{T}
156+
axis::AX
157+
end
158+
QSlice(axis) = QSlice{eltype(axis),typeof(axis)}(axis)
159+
QSlice(S::QSlice) = S
160+
axes(S::QSlice) = (S,)
161+
unsafe_indices(S::QSlice) = (S,)
162+
axes1(S::QSlice) = S
163+
axes(S::QSlice{<:OneTo}) = (S.axis,)
164+
unsafe_indices(S::QSlice{<:OneTo}) = (S.axis,)
165+
axes1(S::QSlice{<:OneTo}) = S.axis
166+
167+
first(S::QSlice) = first(S.axis)
168+
last(S::QSlice) = last(S.axis)
169+
size(S::QSlice) = (length(S.axis),)
170+
length(S::QSlice) = length(S.axis)
171+
unsafe_length(S::QSlice) = unsafe_length(S.axis)
172+
getindex(S::QSlice, i::Real) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
173+
getindex(S::QSlice, i::AbstractVector{<:Real}) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
174+
show(io::IO, r::QSlice) = print(io, "QSlice(", r.axis, ")")
175+
iterate(S::QSlice, s...) = iterate(S.axis, s...)

src/QuasiArrays/matmul.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ QuasiMatMulQuasiMat{styleA, styleB, T, V} = QuasiArrayMulQuasiArray{styleA, styl
2828
*(A::AbstractArray, B::AbstractQuasiArray) = materialize(Mul(A,B))
2929
inv(A::AbstractQuasiArray) = materialize(Inv(A))
3030

31-
*(A::AbstractQuasiArray, B::Mul) = materialize(Mul(A, B.factors...))
32-
*(A::Mul, B::AbstractQuasiArray) = materialize(Mul(A.factors..., B))
31+
*(A::AbstractQuasiArray, B::Mul) = Mul(A, B.factors...)
32+
*(A::Mul, B::AbstractQuasiArray) = Mul(A.factors..., B)

src/QuasiArrays/multidimensional.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@ end
77
@inline to_indices(A, inds, I::Tuple{Colon, Vararg{Any}}) =
88
(uncolon(inds, I), to_indices(A, _maybetail(inds), tail(I))...)
99

10-
uncolon(inds::Tuple{}, I::Tuple{Colon, Vararg{Any}}) = Slice(OneTo(1))
11-
uncolon(inds::Tuple, I::Tuple{Colon, Vararg{Any}}) = Slice(inds[1])
10+
slice(d::AbstractVector) = Slice(d)
11+
slice(d) = QSlice(d)
12+
13+
uncolon(inds::Tuple{}, I::Tuple{Colon, Vararg{Any}}) = slice(OneTo(1))
14+
uncolon(inds::Tuple, I::Tuple{Colon, Vararg{Any}}) = slice(inds[1])
1215

1316

1417
_maybe_reshape(::IndexLinear, A::AbstractQuasiArray, I...) = A

src/QuasiArrays/subarray.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ function view(A::AbstractQuasiArray, I::Vararg{Any,N}) where {N}
6969
unsafe_view(_maybe_reshape_parent(A, index_ndims(J...)), J...)
7070
end
7171

72-
function unsafe_view(A::AbstractQuasiArray, I::Vararg{ViewIndex,N}) where {N}
72+
const QViewIndex = Union{ViewIndex,AbstractQuasiArray}
73+
74+
function unsafe_view(A::AbstractQuasiArray, I::Vararg{QViewIndex,N}) where {N}
7375
@_inline_meta
7476
SubQuasiArray(A, I)
7577
end

src/bases/splines.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const HeavisideSpline{T} = Spline{0,T}
99
Spline{o}(pts::AbstractVector{T}) where {o,T} = Spline{o,float(T)}(pts)
1010

1111
axes(B::Spline{o}) where o = (first(B.points)..last(B.points), Base.OneTo(length(B.points)+o-1))
12+
==(A::Spline{o}, B::Spline{o}) where o = A.points == B.points
1213

1314
function getindex(B::LinearSpline{T}, x::Real, k::Int) where T
1415
x axes(B,1) && 1 k  size(B,2)|| throw(BoundsError())

src/operators.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ DiracDelta(axis) = DiracDelta(zero(float(eltype(axis))), axis)
1111
axes::DiracDelta) =.axis,)
1212
IndexStyle(::Type{<:DiracDelta}) = IndexLinear()
1313

14+
==(a::DiracDelta, b::DiracDelta) = a.axis == b.axis && a.x == b.x
15+
1416
function getindex::DiracDelta{T}, x::Real) where T
1517
x δ.axis || throw(BoundsError())
1618
x == δ.x ? inv(zero(T)) : zero(T)
@@ -37,3 +39,4 @@ Derivative{T}(axis::A) where {T,A} = Derivative{T,A}(axis)
3739
Derivative(axis) = Derivative{Float64}(axis)
3840

3941
axes(D::Derivative) = (D.axis, D.axis)
42+
==(a::Derivative, b::Derivative) = a.axis == b.axis

test/runtests.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, Test
2-
import ContinuumArrays: ℵ₁
1+
using ContinuumArrays, LazyArrays, IntervalSets, FillArrays, LinearAlgebra, BandedMatrices, Test
2+
import ContinuumArrays: ℵ₁, materialize
33

44
@testset "DiracDelta" begin
55
δ = DiracDelta(-1..3)
@@ -80,23 +80,27 @@ end
8080
@test fp[2.2] 2
8181
end
8282

83+
@testset "Weak Laplacian" begin
84+
H = HeavisideSpline(0:2)
85+
L = LinearSpline(0:2)
8386

87+
D = Derivative(axes(L,1))
88+
M = materialize(Mul(D',D,L))
89+
DL = D*L
90+
@test M.factors == tuple(D', (D*L).factors...)
8491

92+
@test materialize(Mul(L', D', D, L)) == materialize(L'D'*D*L) ==
93+
[1.0 -1 0; -1.0 2.0 -1.0; 0.0 -1.0 1.0]
8594

86-
L = LinearSpline([1,2,3])
87-
f = L*[1,2,4]
88-
D = Derivative(axes(L,1))
89-
90-
H = HeavisideSpline([1,2,3])
91-
92-
D*L
93-
M = (L'D')*(D*L)
94-
95-
96-
*(M.factors...)
97-
98-
95+
@test materialize(Mul(L', D', D, L)) isa BandedMatrix
96+
@test materialize(L'D'*D*L) isa BandedMatrix
9997

100-
LazyArrays.MemoryLayout(Diagonal(randn(5)))
98+
@test bandwidths(materialize(L'D'*D*L)) == (1,1)
99+
end
101100

102-
M.factors[2] * M.factors[3]
101+
@testset "Views" begin
102+
L = LinearSpline(0:2)
103+
@test view(L,0.1,1)[1] == L[0.1,1]
104+
end
105+
L = LinearSpline(0:2)
106+
L[:,1]

0 commit comments

Comments
 (0)