Skip to content

Commit fbc7d72

Browse files
authored
muladd! with adjoints/transposes (#233)
1 parent f9fd686 commit fbc7d72

File tree

4 files changed

+14
-2
lines changed

4 files changed

+14
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.9.2"
4+
version = "1.9.3"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ Base.permutedims(D::Diagonal{<:Any,<:LayoutVector}) = D
285285

286286
zero!(A) = zero!(MemoryLayout(A), A)
287287
zero!(_, A) = fill!(A,zero(eltype(A)))
288+
function zero!(::DualLayout, A)
289+
zero!(parent(A))
290+
A
291+
end
288292
function zero!(_, A::AbstractArray{<:AbstractArray})
289293
for a in A
290294
zero!(a)

src/muladd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ materialize(M::MulAdd) = copy(instantiate(M))
7777
copy(M::MulAdd) = copyto!(similar(M), M)
7878

7979
_fill_copyto!(dest, C) = copyto!(dest, C)
80-
_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload
80+
_fill_copyto!(dest, C::Union{Zeros,AdjOrTrans{<:Any,<:Zeros}}) = zero!(dest) # exploit special fill! overload
8181

8282
@inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T =
8383
muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C); Czero = M.Czero)

test/test_muladd.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,14 @@ Random.seed!(0)
844844
@test copy(M) b * D * α + c * β
845845
end
846846
end
847+
848+
@testset "dual" begin
849+
a = randn(5)
850+
X = randn(5,6)
851+
@test copyto!(similar(a,6)', MulAdd(2.0, a', X, 3.0, Zeros(6)')) 2a'*X
852+
@test copyto!(transpose(similar(a,6)), MulAdd(2.0, a', X, 3.0, Zeros(6)')) 2a'*X
853+
@test copyto!(transpose(similar(a,6)), MulAdd(2.0, transpose(a), X, 3.0, transpose(Zeros(6)))) 2a'*X
854+
end
847855
end
848856

849857
end

0 commit comments

Comments
 (0)