Skip to content

Commit 3c7ce24

Browse files
committed
Fixed BlockDiagonal
1 parent 657b480 commit 3c7ce24

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

src/BlockBandedMatrices.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ import BandedMatrices: isbanded, bandwidths, bandwidth, banded_getindex, colrang
4242
_banded_colval, _banded_rowval, _banded_nzval # for sparse
4343

4444
export BandedBlockBandedMatrix, BlockBandedMatrix, BlockSkylineMatrix, blockbandwidth, blockbandwidths,
45-
subblockbandwidth, subblockbandwidths, Ones, Zeros, Fill, Block, BlockTridiagonal, isblockbanded
45+
subblockbandwidth, subblockbandwidths, Ones, Zeros, Fill, Block, BlockDiagonal, BlockTridiagonal, isblockbanded
4646

4747

4848
const Block1 = Block{1,Int}

src/interfaceimpl.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@ const BlockDiagonal{T,VT<:Matrix{T}} = BlockMatrix{T,<:Diagonal{VT}}
2525

2626
BlockDiagonal(A) = mortar(Diagonal(A))
2727

28-
function sizes_from_blocks(A::Diagonal, _)
28+
function sizes_from_blocks(A::Diagonal, _)
2929
# for k = 1:length(A.du)
3030
# size(A.du[k],1) == sz[1][k] || throw(ArgumentError("block sizes of upper diagonal inconsisent with diagonal"))
3131
# size(A.du[k],2) == sz[2][k+1] || throw(ArgumentError("block sizes of upper diagonal inconsisent with diagonal"))
3232
# size(A.dl[k],1) == sz[1][k+1] || throw(ArgumentError("block sizes of lower diagonal inconsisent with diagonal"))
3333
# size(A.dl[k],2) == sz[2][k] || throw(ArgumentError("block sizes of lower diagonal inconsisent with diagonal"))
3434
# end
35-
BlockSizes(size.(A.diag, 1), size.(A.diag,2))
36-
end
35+
(size.(A.diag, 1), size.(A.diag,2))
36+
end
3737

3838

3939
# Block Tridiagonal
4040
const BlockTridiagonal{T,VT<:Matrix{T}} = BlockMatrix{T,<:Tridiagonal{VT}}
4141

4242
BlockTridiagonal(A,B,C) = mortar(Tridiagonal(A,B,C))
4343

44-
function sizes_from_blocks(A::Tridiagonal, _)
44+
function sizes_from_blocks(A::Tridiagonal, _)
4545
# for k = 1:length(A.du)
4646
# size(A.du[k],1) == sz[1][k] || throw(ArgumentError("block sizes of upper diagonal inconsisent with diagonal"))
4747
# size(A.du[k],2) == sz[2][k+1] || throw(ArgumentError("block sizes of upper diagonal inconsisent with diagonal"))
@@ -64,11 +64,20 @@ checksquareblocks(A) = blockisequal(axes(A)...) || throw(DimensionMismatch("bloc
6464

6565
for op in (:-, :+)
6666
@eval begin
67-
function $op(A::BlockTridiagonal, λ::UniformScaling)
67+
function $op(A::BlockDiagonal, λ::UniformScaling)
68+
checksquareblocks(A)
69+
mortar(Diagonal(broadcast($op, A.blocks.diag, Ref(λ))))
70+
end
71+
function $op::UniformScaling, A::BlockDiagonal)
72+
checksquareblocks(A)
73+
mortar(Diagonal(broadcast($op, Ref(λ), A.blocks.diag)))
74+
end
75+
76+
function $op(A::BlockTridiagonal, λ::UniformScaling)
6877
checksquareblocks(A)
6978
mortar(Tridiagonal(A.blocks.dl, broadcast($op, A.blocks.d, Ref(λ)), A.blocks.du))
7079
end
71-
function $op::UniformScaling, A::BlockTridiagonal)
80+
function $op::UniformScaling, A::BlockTridiagonal)
7281
checksquareblocks(A)
7382
mortar(Tridiagonal(A.blocks.dl, broadcast($op, Ref(λ), A.blocks.d), A.blocks.du))
7483
end
@@ -87,4 +96,4 @@ function replace_in_print_matrix(A::BlockTridiagonal, i::Integer, j::Integer, s:
8796
I,J = block.(bi)
8897
i,j = blockindex.(bi)
8998
-1 Int(J-I)  1 ? s : Base.replace_with_centered_mark(s)
90-
end
99+
end

test/test_misc.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ import BandedMatrices: bandwidths, AbstractBandedMatrix, BandedStyle, bandeddata
2121
@test BandedMatrix(V) == V
2222
end
2323

24+
@testset "Block Diagonal" begin
25+
A = BlockDiagonal(fill([1 2],3))
26+
@test blockbandwidths(A) == (0,0)
27+
@test isblockbanded(A)
28+
@test A[Block(1,1)] == [1 2]
29+
@test @inferred(getblock(A,1,2)) == @inferred(A[Block(1,2)]) == [0 0]
30+
@test_throws DimensionMismatch A+I
31+
A = BlockDiagonal(fill([1 2; 1 2],3))
32+
@test A+I == I+A == mortar(Diagonal(fill([2 2; 1 3],3))) == Matrix(A) + I
33+
end
34+
2435
@testset "Block Tridiagonal" begin
2536
A = BlockTridiagonal(fill([1 2],3), fill([3 4],4), fill([4 5],3))
2637
@test blockbandwidths(A) == (1,1)

0 commit comments

Comments
 (0)