Skip to content

Commit 4d71789

Browse files
committed
better subbassis broadcastcasis
1 parent c4fcf5e commit 4d71789

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

src/bases/bases.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,12 @@ function _broadcastbasis(::typeof(+), _, _, a, b)
323323
end
324324

325325
_broadcastbasis(::typeof(+), ::MappedBasisLayouts, ::MappedBasisLayouts, a, b) = broadcastbasis(+, demap(a), demap(b))[basismap(a), :]
326-
_broadcastbasis(::typeof(+), ::SubBasisLayout, ::SubBasisLayout, a, b) = broadcastbasis(+, parent(a), parent(b))
326+
function _broadcastbasis(::typeof(+), ::SubBasisLayout, ::SubBasisLayout, a, b)
327+
kr_a,jr_a = parentindices(a)
328+
kr_b,jr_b = parentindices(b)
329+
@assert kr_a == kr_b # frist axes must match
330+
view(broadcastbasis(+, parent(a), parent(b)), kr_a, union(jr_a,jr_b))
331+
end
327332
_broadcastbasis(::typeof(+), ::SubBasisLayout, _, a, b) = broadcastbasis(+, parent(a), b)
328333
_broadcastbasis(::typeof(+), _, ::SubBasisLayout, a, b) = broadcastbasis(+, a, parent(b))
329334

src/bases/splines.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ const HeavisideSpline = Spline{0}
99
Spline{o}(pts::AbstractVector{T}) where {o,T} = Spline{o,float(T)}(pts)
1010
Spline{o}(S::Spline) where {o} = Spline{o}(S.points)
1111

12+
summary(io::IO, L::LinearSpline) = print(io, "LinearSpline($(L.points))")
13+
1214
axes(B::Spline{o}) where o =
1315
(Inclusion(first(B.points)..last(B.points)), OneTo(length(B.points)+o-1))
1416
==(A::Spline{o}, B::Spline{o}) where o = A.points == B.points

test/test_splines.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ using ContinuumArrays, LinearAlgebra, Test
109109
@test (L\b) == [2,2,2]
110110
B = BroadcastQuasiArray(-, L, L)
111111
@test L\B == 0Eye(3)
112+
113+
@testset "sub" begin
114+
v = ApplyQuasiArray(*, L[:,2:end], [1,2])
115+
f = L * [1,2,3]
116+
@test v + f == f + v == L*[1,3,5]
117+
@test v + v == L*[0,2,4]
118+
end
112119
end
113120
end
114121

0 commit comments

Comments
 (0)