Skip to content

Commit 4e03466

Browse files
authored
sum for ApplyLayout(*) and AdjointBasisLayout (#90)
1 parent 74b5c8a commit 4e03466

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.7.1"
3+
version = "0.7.2"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/bases/bases.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,19 @@ function __sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, ::Colon)
404404
first(apply(*, sum(a[1]; dims=1), tail(a)...))
405405
end
406406

407+
__sum(::AdjointBasisLayout, Vm::AbstractQuasiMatrix, dims) = permutedims(sum(Vm'; dims=(isone(dims) ? 2 : 1)))
408+
409+
# sum is equivalent to hitting by ones(n) on the left or rifght
410+
function __sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiMatrix, d::Int)
411+
a = arguments(LAY, V)
412+
if d == 1
413+
*(sum(first(a); dims=1), tail(a)...)
414+
else
415+
@assert d == 2
416+
*(most(a)..., sum(last(a); dims=2))
417+
end
418+
end
419+
407420
function __sum(::MappedBasisLayouts, V::AbstractQuasiArray, dims)
408421
kr = basismap(V)
409422
@assert kr isa AbstractAffineQuasiVector

test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,14 @@ end
347347
B = H[5x .+ 1,:]
348348
@test sum(B; dims=1) * [1,1,1] == [1]
349349
@test sum(H[:,1:2]; dims=1) * [1,1] == [2]
350+
@test sum(H'; dims=2) == permutedims(sum(H; dims=1))
350351

351352
u = H * randn(3)
352353
@test sum(u[5x .+ 1]) sum(view(u,5x .+ 1)) sum(u)/5
354+
355+
L = LinearSpline([1,2,3,6])
356+
D = Derivative(axes(L,1))
357+
@test sum(D*L; dims=1) sum((D*L)'; dims=2)' [-1 zeros(1,2) 1]
353358
end
354359

355360
@testset "Poisson" begin
@@ -590,4 +595,4 @@ ContinuumArrays.invmap(::InvQuadraticMap{T}) where T = QuadraticMap{T}()
590595
end
591596
end
592597

593-
include("test_basisconcat.jl")
598+
include("test_basisconcat.jl")

0 commit comments

Comments
 (0)