Skip to content

Commit f4acc53

Browse files
authored
Support 3-arg + (#119)
* Support 3-arg + * v0.9.4 * better subbassis broadcastcasis
1 parent 28adb52 commit f4acc53

File tree

4 files changed

+36
-9
lines changed

4 files changed

+36
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.9.3"
3+
version = "0.9.4"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/bases/bases.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -323,18 +323,34 @@ function _broadcastbasis(::typeof(+), _, _, a, b)
323323
end
324324

325325
_broadcastbasis(::typeof(+), ::MappedBasisLayouts, ::MappedBasisLayouts, a, b) = broadcastbasis(+, demap(a), demap(b))[basismap(a), :]
326+
function _broadcastbasis(::typeof(+), ::SubBasisLayout, ::SubBasisLayout, a, b)
327+
kr_a,jr_a = parentindices(a)
328+
kr_b,jr_b = parentindices(b)
329+
@assert kr_a == kr_b # frist axes must match
330+
view(broadcastbasis(+, parent(a), parent(b)), kr_a, union(jr_a,jr_b))
331+
end
332+
_broadcastbasis(::typeof(+), ::SubBasisLayout, _, a, b) = broadcastbasis(+, parent(a), b)
333+
_broadcastbasis(::typeof(+), _, ::SubBasisLayout, a, b) = broadcastbasis(+, a, parent(b))
326334

327335
broadcastbasis(::typeof(+), a, b) = _broadcastbasis(+, MemoryLayout(a), MemoryLayout(b), a, b)
336+
broadcastbasis(::typeof(+), a, b, c...) = broadcastbasis(+, broadcastbasis(+, a, b), c...)
328337

329338
broadcastbasis(::typeof(-), a, b) = broadcastbasis(+, a, b)
330339

331-
for op in (:+, :-)
332-
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), f::Expansion, g::Expansion)
333-
S,c = arguments(f)
334-
T,d = arguments(g)
335-
ST = broadcastbasis($op, S, T)
336-
ST * $op((ST \ S) * c , (ST \ T) * d)
337-
end
340+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(-), f::Expansion, g::Expansion)
341+
S,c = arguments(f)
342+
T,d = arguments(g)
343+
ST = broadcastbasis(-, S, T)
344+
ST * ((ST \ S) * c - (ST \ T) * d)
345+
end
346+
347+
_plus_P_ldiv_Ps_cs(P, ::Tuple{}, ::Tuple{}) = ()
348+
_plus_P_ldiv_Ps_cs(P, Q::Tuple, cs::Tuple) = tuple((P \ first(Q)) * first(cs), _plus_P_ldiv_Ps_cs(P, tail(Q), tail(cs))...)
349+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(+), fs::Expansion...)
350+
Ps = first.(arguments.(fs))
351+
cs = last.(arguments.(fs))
352+
P = broadcastbasis(+, Ps...)
353+
P * +(_plus_P_ldiv_Ps_cs(P, Ps, cs)...) # +((Ref(P) .\ Ps .* cs)...)
338354
end
339355

340356
function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof(*), a::Expansion, f::Expansion)

src/bases/splines.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ const HeavisideSpline = Spline{0}
99
Spline{o}(pts::AbstractVector{T}) where {o,T} = Spline{o,float(T)}(pts)
1010
Spline{o}(S::Spline) where {o} = Spline{o}(S.points)
1111

12+
summary(io::IO, L::LinearSpline) = print(io, "LinearSpline($(L.points))")
13+
1214
axes(B::Spline{o}) where o =
1315
(Inclusion(first(B.points)..last(B.points)), OneTo(length(B.points)+o-1))
1416
==(A::Spline{o}, B::Spline{o}) where o = A.points == B.points

test/test_splines.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,25 @@ using ContinuumArrays, LinearAlgebra, Test
9797

9898
@testset "+" begin
9999
L = LinearSpline([1,2,3])
100-
b = L*[3,4,5] + L*[1.,2,3]
100+
b = @inferred(L*[3,4,5] + L*[1.,2,3])
101101
@test ApplyStyle(\, typeof(L), typeof(b)) == LdivStyle()
102102
@test (L\b) == [4,6,8]
103103
B = BroadcastQuasiArray(+, L, L)
104104
@test L\B == 2Eye(3)
105105

106+
@test L*[3,4,5] + L*[1.,2,3] + L*[4,5,6] == @inferred(broadcast(+, L*[3,4,5], L*[1.,2,3], L*[4,5,6])) == L*[8,11,14]
107+
106108
b = L*[3,4,5] - L*[1.,2,3]
107109
@test (L\b) == [2,2,2]
108110
B = BroadcastQuasiArray(-, L, L)
109111
@test L\B == 0Eye(3)
112+
113+
@testset "sub" begin
114+
v = ApplyQuasiArray(*, L[:,2:end], [1,2])
115+
f = L * [1,2,3]
116+
@test v + f == f + v == L*[1,3,5]
117+
@test v + v == L*[0,2,4]
118+
end
110119
end
111120
end
112121

0 commit comments

Comments
 (0)