|
# Layout-agnostic fallback: with no structural information, every block of the
# second block axis of `A` is reported as possibly non-zero.
blockrowsupport(_, A, k) = blockaxes(A,2)

"""
    blockrowsupport(A, k)

Return an iterator containing the possible non-zero blocks in the `k`-th block-row of `A`.
The result is computed by dispatching on `MemoryLayout(A)`.

    blockrowsupport(A)

Shorthand for `blockrowsupport(A, blockaxes(A,1))`.
"""
blockrowsupport(A, k) = blockrowsupport(MemoryLayout(A), A, k)
blockrowsupport(A) = blockrowsupport(A, blockaxes(A,1))
| 9 | + |
# Layout-agnostic fallback: with no structural information, every block of the
# first block axis of `A` is reported as possibly non-zero.
blockcolsupport(_, A, j) = blockaxes(A,1)

"""
    blockcolsupport(A, j)

Return an iterator containing the possible non-zero blocks in the `j`-th block-column of `A`.
The result is computed by dispatching on `MemoryLayout(A)`.

    blockcolsupport(A)

Shorthand for `blockcolsupport(A, blockaxes(A,2))`.
"""
blockcolsupport(A, j) = blockcolsupport(MemoryLayout(A), A, j)
blockcolsupport(A) = blockcolsupport(A, blockaxes(A,2))
| 19 | + |
# First entry of `blockcolsupport` evaluated on the same arguments.
function blockcolstart(A...)
    return first(blockcolsupport(A...))
end

# Last entry of `blockcolsupport` evaluated on the same arguments.
function blockcolstop(A...)
    return last(blockcolsupport(A...))
end

# First entry of `blockrowsupport` evaluated on the same arguments.
function blockrowstart(A...)
    return first(blockrowsupport(A...))
end

# Last entry of `blockrowsupport` evaluated on the same arguments.
function blockrowstop(A...)
    return last(blockrowsupport(A...))
end
| 24 | + |
# Let each start/stop query accept a `Block{1}` index by converting it to an
# `Int` and forwarding to the generic method of the same function.
for fname in (:blockcolstart, :blockcolstop, :blockrowstart, :blockrowstop)
    @eval function $fname(A, i::Block{1})
        return $fname(A, Int(i))
    end
end
| 28 | + |
| 29 | + |
# Supertype of the memory layouts used for block-structured arrays; the block
# multiplication and triangular methods in this file dispatch on it.
abstract type AbstractBlockLayout <: MemoryLayout end
# Concrete block layout. `LAY` is a layout *type*: the `MemoryLayout` method for
# `BlockArray` fills it with `typeof(MemoryLayout(R))`, and `sublayout`
# instantiates it via `LAY()`.
struct BlockLayout{LAY} <: AbstractBlockLayout end
| 32 | + |
| 33 | + |
# Destination allocation for a multiply-add whose first two operands both have a
# block layout: delegates to `similar(BlockArray{T}, axes)`, so the result is
# requested as a `BlockArray` with element type `T` over the given axes.
# (The previous signature also declared a type variable `N` that was never used.)
similar(M::MulAdd{<:AbstractBlockLayout,<:AbstractBlockLayout}, ::Type{T}, axes) where {T} =
    similar(BlockArray{T}, axes)
| 36 | + |
# A `PseudoBlockArray` type reports the layout of its third type parameter `R`.
function MemoryLayout(::Type{<:PseudoBlockArray{T,N,R}}) where {T,N,R}
    return MemoryLayout(R)
end

# A `BlockArray` type reports a `BlockLayout` parameterised by the layout type
# obtained from its third type parameter `R`.
function MemoryLayout(::Type{<:BlockArray{T,N,R}}) where {T,N,R}
    inner = MemoryLayout(R)
    return BlockLayout{typeof(inner)}()
end
| 39 | + |
# Indexing a block layout with a tuple of `BlockSlice1` yields an instance of
# the wrapped layout type `LAY`.
function sublayout(::BlockLayout{LAY}, ::Type{NTuple{N,BlockSlice1}}) where {LAY,N}
    return LAY()
end

# Indexing with any other tuple of `BlockSlice`s keeps the block layout as is.
function sublayout(BL::BlockLayout, ::Type{<:NTuple{N,BlockSlice}}) where N
    return BL
end
| 42 | + |
# Complex element type: conjugate the wrapped layout and re-wrap the result.
function conjlayout(::Type{T}, ::BlockLayout{LAY}) where {T<:Complex,LAY}
    conjugated = conjlayout(T, LAY())
    return BlockLayout{typeof(conjugated)}()
end

# Real element type: the block layout is returned unchanged.
function conjlayout(::Type{T}, ::BlockLayout{LAY}) where {T<:Real,LAY}
    return BlockLayout{LAY}()
end
| 45 | + |
# Transpose the wrapped layout and re-wrap the result in a `BlockLayout`.
function transposelayout(::BlockLayout{LAY}) where LAY
    transposed = transposelayout(LAY())
    return BlockLayout{typeof(transposed)}()
end
| 47 | + |
| 48 | + |
| 49 | + |
| 50 | +############# |
| 51 | +# BLAS overrides |
| 52 | +############# |
| 53 | + |
| 54 | + |
# Block matrix times strided vector, accumulated into a strided vector
# (`M.C` is updated in place and returned).
# Throws `DimensionMismatch` when the vector lengths do not match `size(A)`.
function materialize!(M::MatMulVecAdd{<:AbstractBlockLayout,<:AbstractStridedLayout,<:AbstractStridedLayout})
    α, A, β = M.α, M.A, M.β
    xvec, yvec = M.B, M.C
    (length(xvec) == size(A,2) && length(yvec) == size(A,1)) || throw(DimensionMismatch())

    # wrap the plain vectors with the block axes of A so they can be indexed by block
    yblk = PseudoBlockArray(yvec, (axes(A,1),))
    xblk = PseudoBlockArray(xvec, (axes(A,2),))

    # β is handled once up front (`_fill_lmul!` is defined elsewhere); every block
    # product below therefore accumulates with coefficient one
    _fill_lmul!(β, yblk)

    for J = blockaxes(A,2), K = blockcolsupport(A,J)
        muladd!(α, view(A,K,J), view(xblk,J), one(α), view(yblk,K))
    end
    return yvec
end
| 74 | + |
# Block-by-block multiply-add kernel: applies `_fill_lmul!(β, Y)` once, then adds
# α * A[K,N] * X[N,J] into Y[K,J] for every block triple in the column supports
# of X and A. Returns Y.
function _block_muladd!(α, A, X, β, Y)
    _fill_lmul!(β, Y)
    for J = blockaxes(X,2)
        for N = blockcolsupport(X,J)
            for K = blockcolsupport(A,N)
                muladd!(α, view(A,K,N), view(X,N,J), one(α), view(Y,K,J))
            end
        end
    end
    return Y
end
| 82 | + |
# True when the block structures of A, B and C line up for C ← A*B:
# columns of A with rows of B, rows of A with rows of C, columns of B with columns of C.
function mul_blockscompatible(A, B, C)
    blockisequal(axes(A,2), axes(B,1)) || return false
    blockisequal(axes(A,1), axes(C,1)) || return false
    return blockisequal(axes(B,2), axes(C,2))
end
| 86 | + |
# Multiply-add where all three operands have a block layout.
function materialize!(M::MatMulMatAdd{<:AbstractBlockLayout,<:AbstractBlockLayout,<:AbstractBlockLayout})
    α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
    # matching block structures: multiply block by block
    mul_blockscompatible(A,B,C) && return _block_muladd!(α, A, B, β, C)
    # otherwise use the default method selected by `UnknownLayout`
    return materialize!(MulAdd{UnknownLayout,UnknownLayout,UnknownLayout}(α, A, B, β, C))
end
| 95 | + |
# Block A times block X accumulated into a column-major destination. The
# destination is wrapped with the row block axes of A so the block kernel can
# address it; the unwrapped destination `M.C` is returned.
function materialize!(M::MatMulMatAdd{<:AbstractBlockLayout,<:AbstractBlockLayout,<:AbstractColumnMajor})
    α, A, X, β = M.α, M.A, M.B, M.β
    dest = M.C
    blocked = PseudoBlockArray(dest, (axes(A,1), axes(dest,2)))
    _block_muladd!(α, A, X, β, blocked)
    return dest
end
| 102 | + |
# Block A times column-major X accumulated into a column-major destination.
# X and the destination are wrapped as `PseudoBlockArray`s carrying the block
# axes of A so the block kernel can address them; the unwrapped destination
# `M.C` is returned.
function materialize!(M::MatMulMatAdd{<:AbstractBlockLayout,<:AbstractColumnMajor,<:AbstractColumnMajor})
    α, A, X_in, β, Y_in = M.α, M.A, M.B, M.β, M.C
    X = PseudoBlockArray(X_in, (axes(A,2), axes(X_in,2)))
    Y = PseudoBlockArray(Y_in, (axes(A,1), axes(Y_in,2)))
    # `_block_muladd!` already applies `_fill_lmul!(β, Y)`. Y wraps Y_in (the result
    # is handed back through Y_in), so also pre-applying β to Y_in would apply it twice.
    _block_muladd!(α, A, X, β, Y)
    return Y_in
end
| 111 | + |
# Column-major A times block X accumulated into a column-major destination.
# A and the destination are wrapped as `PseudoBlockArray`s carrying the block
# axes of X so the block kernel can address them; the unwrapped destination
# `M.C` is returned.
function materialize!(M::MatMulMatAdd{<:AbstractColumnMajor,<:AbstractBlockLayout,<:AbstractColumnMajor})
    α, A_in, X, β, Y_in = M.α, M.A, M.B, M.β, M.C
    A = PseudoBlockArray(A_in, (axes(A_in,1),axes(X,1)))
    Y = PseudoBlockArray(Y_in, (axes(Y_in,1),axes(X,2)))
    # `_block_muladd!` already applies `_fill_lmul!(β, Y)`. Y wraps Y_in (the result
    # is handed back through Y_in), so also pre-applying β to Y_in would apply it twice.
    _block_muladd!(α, A, X, β, Y)
    return Y_in
end
| 120 | + |
| 121 | + |
| 122 | + |
| 123 | +#### |
| 124 | +# Triangular |
| 125 | +#### |
| 126 | + |
# True when all axes of `A` have equal block structure (per `blockisequal`).
@inline function hasmatchingblocks(A)
    return blockisequal(axes(A)...)
end
| 128 | + |
# The triangular layout of a block layout is the triangular layout type `Tri`
# parameterised by that block layout's type.
function triangularlayout(::Type{Tri}, ::ML) where {Tri,ML<:AbstractBlockLayout}
    return Tri{ML}()
end
| 130 | + |
# Wrap `A` in the triangular type selected by two character flags:
# first flag 'U'/'L' = upper/lower, second flag 'N'/'U' = non-unit/unit diagonal.
function _triangular_matrix(::Val{'U'}, ::Val{'N'}, A)
    return UpperTriangular(A)
end

function _triangular_matrix(::Val{'L'}, ::Val{'N'}, A)
    return LowerTriangular(A)
end

function _triangular_matrix(::Val{'U'}, ::Val{'U'}, A)
    return UnitUpperTriangular(A)
end

function _triangular_matrix(::Val{'L'}, ::Val{'U'}, A)
    return UnitLowerTriangular(A)
end
| 135 | + |
| 136 | + |
# In-place upper-triangular block multiply of `dest` by the data `A`, for the
# case where the row and column blocks of `A` match. Returns `dest`.
function _matchingblocks_triangular_mul!(::Val{'U'}, UNIT, A::AbstractMatrix{T}, dest) where T
    # view dest through the row block structure of A
    blocked = PseudoBlockArray(dest, (axes(A,1),))

    # Blocks are visited in ascending order, so the blocks to the right of the
    # diagonal that feed block K have not been overwritten yet.
    for K = blockaxes(A,1)
        bK = view(blocked, K)
        diagblock = _triangular_matrix(Val('U'), UNIT, view(A, K,K))
        materialize!(Lmul(diagblock, bK))
        trailing = (K+1):last(blockrowsupport(A,K))
        isempty(trailing) && continue
        muladd!(one(T), view(A, K, trailing), view(blocked,trailing), one(T), bK)
    end
    return dest
end
| 152 | + |
# In-place lower-triangular block multiply of `dest` by the data `A`, for the
# case where the row and column blocks of `A` match. Returns `dest`.
function _matchingblocks_triangular_mul!(::Val{'L'}, UNIT, A::AbstractMatrix{T}, dest) where T
    # view dest through the row block structure of A
    blocked = PseudoBlockArray(dest, (axes(A,1),))
    nblocks = blocksize(A,1)

    # Blocks are visited in descending order, so the blocks to the left of the
    # diagonal that feed block K have not been overwritten yet.
    for K = nblocks:-1:1
        bK = view(blocked, Block(K))
        diagblock = _triangular_matrix(Val('L'), UNIT, view(A, Block(K,K)))
        materialize!(Lmul(diagblock, bK))
        leading = blockrowstart(A,K):Block(K-1)
        isempty(leading) && continue
        muladd!(one(T), view(A, Block(K), leading), view(blocked,leading), one(T), bK)
    end

    return dest
end
| 171 | + |
# In-place product of a triangular block matrix with a strided vector.
# Throws `BoundsError` (under bounds checking) when the first dimensions differ.
@inline function materialize!(M::MatLmulVec{<:TriangularLayout{UPLO,UNIT,<:AbstractBlockLayout},
                                   <:AbstractStridedLayout}) where {UPLO,UNIT}
    tri, rhs = M.A, M.B
    T = eltype(M)
    @boundscheck size(tri,1) == size(rhs,1) || throw(BoundsError())

    # matching row/column blocks: use the specialised triangular kernel
    hasmatchingblocks(tri) &&
        return _matchingblocks_triangular_mul!(Val(UPLO), Val(UNIT), triangulardata(tri), rhs)

    # otherwise multiply block by block, reading from a copy of the vector and
    # writing into the original through a blocked wrapper
    source = PseudoBlockArray(copy(rhs), (axes(tri,2),))
    target = PseudoBlockArray(rhs, (axes(tri,1),))
    return _block_muladd!(one(T), tri, source, zero(T), target)
end
| 185 | + |
| 186 | + |
# Generate in-place triangular solves (block matrix \ strided vector) for both
# diagonal flags ('U' = unit, 'N' = non-unit). Each method overwrites and
# returns `M.B`.
for UNIT in ('U', 'N')
    @eval begin
        # Upper triangular: block back substitution, last block first.
        @inline function materialize!(M::MatLdivVec{<:TriangularLayout{'U',$UNIT,<:AbstractBlockLayout},
                                       <:AbstractStridedLayout})
            tri, dest = M.A, M.B
            T = eltype(dest)
            A = triangulardata(tri)

            # row and column blocks differ: use default for now
            hasmatchingblocks(A) ||
                return materialize!(Ldiv{TriangularLayout{'U',$UNIT,UnknownLayout},
                                         typeof(MemoryLayout(dest))}(tri, dest))

            @boundscheck size(A,1) == size(dest,1) || throw(BoundsError())

            # view dest through the row block structure of A
            blocked = PseudoBlockArray(dest, (axes(A,1),))
            nblocks = blocksize(A,1)

            for K = nblocks:-1:1
                bK = view(blocked, Block(K))
                diagblock = _triangular_matrix(Val('U'), Val($UNIT), view(A, Block(K,K)))
                materialize!(Ldiv(diagblock, bK))

                K ≥ 2 || continue
                # eliminate the solved block from the rows above it
                above = blockcolstart(A, K):Block(K-1)
                offdiag = view(A, above, Block(K))
                pending = view(blocked, above)
                muladd!(-one(T), offdiag, bK, one(T), pending)
            end

            return dest
        end

        # Lower triangular: block forward substitution, first block first.
        @inline function materialize!(M::MatLdivVec{<:TriangularLayout{'L',$UNIT,<:AbstractBlockLayout},
                                       <:AbstractStridedLayout})
            tri, dest = M.A, M.B
            T = eltype(dest)
            A = triangulardata(tri)

            # row and column blocks differ: use default for now
            hasmatchingblocks(A) ||
                return materialize!(Ldiv{TriangularLayout{'L',$UNIT,UnknownLayout},
                                         typeof(MemoryLayout(dest))}(tri, dest))

            @boundscheck size(A,1) == size(dest,1) || throw(BoundsError())

            # view dest through the row block structure of A
            blocked = PseudoBlockArray(dest, (axes(A,1),))
            nblocks = blocksize(A,1)

            for K = 1:nblocks
                bK = view(blocked, Block(K))
                diagblock = _triangular_matrix(Val('L'), Val($UNIT), view(A, Block(K,K)))
                materialize!(Ldiv(diagblock, bK))

                K < nblocks || continue
                # eliminate the solved block from the rows below it
                below = Block(K+1):blockcolstop(A, K)
                offdiag = view(A, below, Block(K))
                pending = view(blocked, below)
                muladd!(-one(T), offdiag, bK, one(T), pending)
            end

            return dest
        end
    end
end
0 commit comments