|
| 1 | +using BlockArrays: _BlockArray, PseudoBlockArray, BlockArray, BlockMatrix, BlockVector, |
| 2 | + nblocks, Block, cumulsizes, AbstractBlockVector |
| 3 | +using BlockBandedMatrices: BandedBlockBandedMatrix, _BandedBlockBandedMatrix, |
| 4 | + blockbandwidths, subblockbandwidths, blockbandwidth, |
| 5 | + BandedBlockBandedSizes |
| 6 | +using LinearAlgebra: BLAS |
| 7 | +import LinearAlgebra |
| 8 | +using BandedMatrices: _BandedMatrix, BandedMatrix |
| 9 | +using SharedArrays |
| 10 | +using LazyArrays |
| 11 | +using Distributed: procs, remotecall_wait |
| 12 | +import Distributed |
| 13 | + |
| 14 | +import Adapt: adapt |
| 15 | + |
# Rebuild a banded-block-banded matrix around `adapt(T, b.data)`, reusing the
# block sizes of `b` unchanged.
function adapt(T::Type, b::BandedBlockBandedMatrix)
    converted = adapt(T, b.data)
    return _BandedBlockBandedMatrix(converted, b.block_sizes)
end
# Re-wrap a pseudo block array so that its storage becomes `T(b.blocks)`,
# keeping the block sizes of `b`.
function adapt(T::Type{<:AbstractArray}, b::PseudoBlockArray)
    storage = T(b.blocks)
    return PseudoBlockArray(storage, b.block_sizes)
end
| 20 | + |
| 21 | + |
# Banded-block-banded matrix whose `data` field is a `PseudoBlockArray` wrapping a
# two-dimensional `SharedArray`.
const SharedBandedBlockBandedMatrix{T} =
    BandedBlockBandedMatrix{T, PseudoBlockArray{T, 2, SharedArray{T, 2}}}
| 24 | + |
"""
    SharedBandedBlockBandedMatrix{T}(undef, bs::BandedBlockBandedSizes; init=nothing, kwargs...)

Allocate a banded-block-banded matrix with block structure `bs` on top of a
`SharedArray`.

# Keywords
- `init`: optional function called as `init(block, i, j)`; when given it is passed to
  `populate!` to fill the new matrix. When left out (or `nothing`) the matrix is
  returned without being populated.
- every other keyword is forwarded to the `SharedArray` constructor.
"""
function SharedBandedBlockBandedMatrix{T}(::UndefInitializer,
                                          bs::BandedBlockBandedSizes;
                                          kwargs...) where T
    # Named so as not to shadow `Block` imported from BlockArrays.
    DataBlocks = fieldtype(SharedBandedBlockBandedMatrix{T}, :data)
    Storage = fieldtype(DataBlocks, :blocks)

    # `init` is ours; the remaining keywords belong to the storage constructor.
    options = Dict(kwargs)
    init = pop!(options, :init, nothing)

    # NOTE(review): the storage is sized with `size(bs)` but wrapped with
    # `bs.data_block_sizes` -- confirm that the two describe the same shape.
    shared = Storage(size(bs); options...)
    result = _BandedBlockBandedMatrix(DataBlocks(shared, bs.data_block_sizes), bs)

    # Without an initializer there is nothing to run: `populate!` has no method
    # accepting `nothing`.
    init === nothing || populate!(result, init)
    return result
end
| 37 | + |
# Delegate to `procs` of the array stored in `A.data.blocks`.
function Distributed.procs(A::SharedBandedBlockBandedMatrix)
    return procs(A.data.blocks)
end
| 39 | + |
"""
    populate!(A::SharedBandedBlockBandedMatrix, range, block_populate!::Function)

Call `block_populate!(view(A, Block(i, j)), i, j)` for each in-band block of `A`
whose running index belongs to `range`, and return `A`.

Blocks are numbered from 1, block-row by block-row. Within block-row `i` the visited
block-columns are `max(i - A.l, 1):min(i + A.u, nblocks(A, 2))`, which is the band
`mul!` traverses, so the number of blocks visited agrees with
`nnzero(nblocks(A)..., A.l, A.u)`.
"""
function populate!(A::SharedBandedBlockBandedMatrix, range, block_populate!::Function)
    k = 0
    for i in 1:nblocks(A, 1)
        # `A.l` block sub-diagonals sit left of the diagonal, `A.u` super-diagonals right
        for j in max(i - A.l, 1):min(i + A.u, nblocks(A, 2))
            k += 1
            k in range && block_populate!(view(A, Block(i, j)), i, j)
        end
    end
    return A
end
| 50 | + |
| 51 | + |
"""
    populate!(A::SharedBandedBlockBandedMatrix, block_populate!::Function)

Split the `nnzero(nblocks(A)..., A.l, A.u)` block indices into contiguous chunks, one
per process in `procs(A)`, and run the ranged `populate!` on each process for its
chunk. Waits for all remote calls to finish and returns `A`.
"""
function populate!(A::SharedBandedBlockBandedMatrix, block_populate!::Function)
    total = nnzero(nblocks(A)..., A.l, A.u)
    pids = procs(A)
    # every process gets `share` blocks; the first `extra` processes get one more
    share, extra = divrem(total, length(pids))
    @sync for (rank, pid) in enumerate(pids)
        first_index = share * (rank - 1) + min(extra, rank - 1) + 1
        last_index = share * rank + min(extra, rank)
        @async remotecall_wait(populate!, pid, A, first_index:last_index, block_populate!)
    end
    return A
end
| 64 | + |
# Function-first argument order, so the initializer can be written as a `do` block.
function populate!(block_populate!::Function, A::SharedBandedBlockBandedMatrix)
    return populate!(A, block_populate!)
end
| 67 | + |
# Allocate with block sizes `bs`, handing `init` and `pids` to the `undef`
# constructor as keywords.
function SharedBandedBlockBandedMatrix{T}(init::Function,
                                          bs::BandedBlockBandedSizes;
                                          pids=Int[]) where T
    return SharedBandedBlockBandedMatrix{T}(undef, bs; pids=pids, init=init)
end
# Build the `BandedBlockBandedSizes` from the two vectors in `dims` and the bandwidth
# pairs `lu` and `λμ`, then defer to the constructor taking the sizes object.
function SharedBandedBlockBandedMatrix{T}(init::Function,
                                          dims::NTuple{2, AbstractVector{Int}},
                                          lu::NTuple{2, Int}, λμ::NTuple{2, Int};
                                          pids=Int[]) where T
    rows, cols = dims
    l, u = lu
    λ, μ = λμ
    sizes = BandedBlockBandedSizes(rows, cols, l, u, λ, μ)
    return SharedBandedBlockBandedMatrix{T}(init, sizes; pids=pids)
end
| 79 | + |
"""
    nnzero(n, m, l, u)

Count the entries of an `n × m` banded matrix that fall on the main diagonal, on the
first `l` sub-diagonals or on the first `u` super-diagonals.
"""
function nnzero(n::Integer, m::Integer, l::Integer, u::Integer)
    # main diagonal (k = 0) together with the sub-diagonals
    lower = foldl(+, (min(n - k, m) for k in 0:min(n, l)); init=zero(n))
    # super-diagonals, accumulated on top of the lower part
    return foldl(+, (min(m - k, n) for k in 1:min(m, u)); init=lower)
end
| 91 | + |
"""
    mul!(c, A::SharedBandedBlockBandedMatrix, x)

Overwrite the block vector `c` with the product of `A` and the block vector `x`.

The block-rows of `A` are split into contiguous chunks, one per process in `procs(A)`.
Each process zeroes its blocks of `c` and then adds, for every in-band block of `A` in
that block-row, the product of the block with the matching block of `x`. Returns `c`.
"""
function LinearAlgebra.mul!(c::AbstractBlockVector{T},
                            A::SharedBandedBlockBandedMatrix{T},
                            x::AbstractBlockVector{T}) where T
    # `c` must be blocked like the rows of `A`, `x` like its columns
    @assert nblocks(A, 1) == nblocks(c, 1)
    @assert cumulsizes(A, 1) == cumulsizes(c, 1)
    @assert nblocks(A, 2) == nblocks(x, 1)
    @assert cumulsizes(A, 2) == cumulsizes(x, 1)

    nrows = nblocks(A, 1)
    pids = procs(A)
    # every process gets `share` block-rows; the first `extra` processes get one more
    share, extra = divrem(nrows, length(pids))

    # NOTE(review): the remote closure writes into `c`; presumably `c` has to be backed
    # by shared memory for the result to be visible on this process -- confirm.
    @sync for (rank, pid) in enumerate(pids)
        # with more processes than block-rows the surplus processes have no work
        if rank <= nrows
            first_row = share * (rank - 1) + min(extra, rank - 1) + 1
            last_row = share * rank + min(extra, rank)

            @async begin
                remotecall_wait(pid, first_row:last_row) do rows
                    @inbounds for i in rows
                        fill!(view(c, Block(i)), zero(eltype(c)))
                        for j in max(1, i - A.l):min(nblocks(A, 2), i + A.u)
                            c[Block(i)] .+= Mul(view(A, Block(i, j)), view(x, Block(j)))
                        end
                    end
                end
            end
        end
    end
    return c
end
| 123 | + |
| 124 | + |
| 125 | +using Test |
| 126 | + |
# Exercise the shared-array backend: construction through an initializer, the band
# element count, and the distributed matrix-vector product.
function testme()
    Shared = SharedBandedBlockBandedMatrix
    @testset "shared array backend" begin

        @testset "Initialization" begin
            rows, cols = fill(2, 4), fill(3, 2)
            A = Shared{Int64}((rows, cols), (1, 1), (1, 0)) do block, i, j
                block .= 0
                # mark a single entry so we can check the initializer ran on block (3, 2)
                if i == 3 && j == 2
                    block[2, 2] = 1
                end
            end
            marked = view(A, Block(3, 2))
            @test marked[2, 2] == 1
            marked[2, 2] = 0
            @test all(A .== 0)
        end

        @testset "count non-zero elements" begin
            for _ in 1:100
                n, m = rand(1:10, 2)
                l, u = rand(0:10, 2)
                # a band filled with ones sums to the number of in-band entries
                band = BandedMatrix{Int8}(undef, n, m, l, u)
                band.data .= 1
                @test sum(band) == nnzero(n, m, l, u)
            end
        end

        @testset "Multiplication" begin
            N, M = rand(1:3, 2)
            l, u, λ, μ = rand(0:2, 4)
            smallest = max(l, u, λ, μ)
            n, m = rand(smallest:20, N), rand(smallest:20, M)
            A = BandedBlockBandedMatrix{Float64}(undef, (n, m), (l, u), (λ, μ))
            A.data .= rand.()
            x = PseudoBlockArray(Vector{Float64}(undef, size(A, 2)), m)
            x .= rand.()

            Ashared = adapt(SharedArray, A)
            @test Ashared.data.blocks isa SharedArray
            @test Ashared isa SharedBandedBlockBandedMatrix
            @test length(procs(Ashared)) == max(1, length(procs()) - 1)

            c = PseudoBlockArray(Vector{Float64}(undef, size(A, 1)), n)
            cshared = adapt(SharedArray, c)
            @test cshared.blocks isa SharedArray
            # start from garbage so the test shows `mul!` overwrites its output
            cshared .= rand.()
            xshared = adapt(SharedArray, x)

            @test LinearAlgebra.mul!(cshared, Ashared, xshared) ≈ A * x
        end

    end
end
0 commit comments