Skip to content

Commit bdd01ca

Browse files
authored
Replace broadcast by loop in LinearAlgebra.exp! for StridedMatrixes (#54520)
1 parent 047e699 commit bdd01ca

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -705,22 +705,29 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
705705
V = mul!(C[3]*P, true, C[1]*I, true, true) #V = C[1]*I + C[3]*P
706706
for k in 2:(div(length(C), 2) - 1)
707707
P *= A2
708-
U .+= C[2k + 2] .* P
709-
V .+= C[2k + 1] .* P
708+
for ind in eachindex(P)
709+
U[ind] += C[2k + 2] * P[ind]
710+
V[ind] += C[2k + 1] * P[ind]
711+
end
710712
end
711713

712714
U = A * U
713715

714716
# Padé approximant: (V-U)\(V+U)
715717
tmp1, tmp2 = A, A2 # Reuse already allocated arrays
716-
tmp1 .= V .- U
717-
tmp2 .= V .+ U
718+
for ind in eachindex(tmp1)
719+
tmp1[ind] = V[ind] - U[ind]
720+
tmp2[ind] = V[ind] + U[ind]
721+
end
718722
X = LAPACK.gesv!(tmp1, tmp2)[1]
719723
else
720724
s = log2(nA/5.4) # power of 2 later reversed by squaring
721725
if s > 0
722726
si = ceil(Int,s)
723-
A ./= convert(T,2^si)
727+
twopowsi = convert(T,2^si)
728+
for ind in eachindex(A)
729+
A[ind] /= twopowsi
730+
end
724731
end
725732
CC = T[64764752532480000.,32382376266240000.,7771770303897600.,
726733
1187353796428800., 129060195264000., 10559470521600.,
@@ -735,22 +742,28 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
735742
# Allocation economical version of:
736743
# U = A * (A6 * (CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2) .+
737744
# CC[8].*A6 .+ CC[6].*A4 .+ CC[4]*A2+CC[2]*I)
738-
tmp1 .= CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2
739-
tmp2 .= CC[8].*A6 .+ CC[6].*A4 .+ CC[4].*A2
745+
for ind in eachindex(tmp1)
746+
tmp1[ind] = CC[14]*A6[ind] + CC[12]*A4[ind] + CC[10]*A2[ind]
747+
tmp2[ind] = CC[8]*A6[ind] + CC[6]*A4[ind] + CC[4]*A2[ind]
748+
end
740749
mul!(tmp2, true,CC[2]*I, true, true) # tmp2 .+= CC[2]*I
741750
U = mul!(tmp2, A6, tmp1, true, true)
742751
U, tmp1 = mul!(tmp1, A, U), A # U = A * U0
743752

744753
# Allocation economical version of:
745754
# V = A6 * (CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2) .+
746755
# CC[7].*A6 .+ CC[5].*A4 .+ CC[3]*A2 .+ CC[1]*I
747-
tmp1 .= CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2
748-
tmp2 .= CC[7].*A6 .+ CC[5].*A4 .+ CC[3].*A2
756+
for ind in eachindex(tmp1)
757+
tmp1[ind] = CC[13]*A6[ind] + CC[11]*A4[ind] + CC[9]*A2[ind]
758+
tmp2[ind] = CC[7]*A6[ind] + CC[5]*A4[ind] + CC[3]*A2[ind]
759+
end
749760
mul!(tmp2, true, CC[1]*I, true, true) # tmp2 .+= CC[1]*I
750761
V = mul!(tmp2, A6, tmp1, true, true)
751762

752-
tmp1 .= V .+ U
753-
tmp2 .= V .- U # tmp2 already contained V but this seems more readable
763+
for ind in eachindex(tmp1)
764+
tmp1[ind] = V[ind] + U[ind]
765+
tmp2[ind] = V[ind] - U[ind] # tmp2 already contained V but this seems more readable
766+
end
754767
X = LAPACK.gesv!(tmp2, tmp1)[1] # X now contains r_13 in Higham 2008
755768

756769
if s > 0

0 commit comments

Comments
 (0)