Skip to content

Commit 9174e9a

Browse files
Introduce a thread-safe butterfly algorithm and multi-threaded fast spherical harmonic transforms
1 parent afceb5b commit 9174e9a

File tree

4 files changed

+24
-23
lines changed

4 files changed

+24
-23
lines changed

src/FastTransforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import Base: getindex, setindex!, Factorization, length
1919
import Base.LinAlg: BlasFloat, BlasInt
2020
import HierarchicalMatrices: HierarchicalMatrix, unsafe_broadcasttimes!
2121
import HierarchicalMatrices: A_mul_B!, At_mul_B!, Ac_mul_B!
22+
import HierarchicalMatrices: ThreadSafeVector, threadsafezeros
2223
import LowRankApprox: ColPerm
2324
import AbstractFFTs: Plan
2425

src/SphericalHarmonics/Butterfly.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ struct Butterfly{T} <: Factorization{T}
33
factors::Vector{Vector{IDPackedV{T}}}
44
permutations::Vector{Vector{ColumnPermutation}}
55
indices::Vector{Vector{Int}}
6-
temp1::Vector{T}
7-
temp2::Vector{T}
8-
temp3::Vector{T}
9-
temp4::Vector{T}
6+
temp1::ThreadSafeVector{T}
7+
temp2::ThreadSafeVector{T}
8+
temp3::ThreadSafeVector{T}
9+
temp4::ThreadSafeVector{T}
1010
end
1111

1212
function size(B::Butterfly, dim::Integer)
@@ -106,7 +106,7 @@ function Butterfly{T}(A::AbstractMatrix{T}, L::Int; isorthogonal::Bool = false,
106106

107107
kk = sumkmax(indices)
108108

109-
Butterfly(columns, factors, permutations, indices, zeros(T, kk), zeros(T, kk), zeros(T, kk), zeros(T, kk))
109+
# The four workspace buffers are built with `threadsafezeros`, the constructor
# imported from HierarchicalMatrices together with `ThreadSafeVector`;
# `threadedzeros` is not among the imported names.
Butterfly(columns, factors, permutations, indices, threadsafezeros(T, kk), threadsafezeros(T, kk), threadsafezeros(T, kk), threadsafezeros(T, kk))
110110
end
111111

112112
function sumkmax(indices::Vector{Vector{Int}})
@@ -119,7 +119,7 @@ end
119119

120120
#### Helper
121121

122-
function rowperm!(fwd::Bool, x::StridedVecOrMat, p::Vector{Int}, jstart::Int)
122+
function rowperm!(fwd::Bool, x::AbstractVecOrMat, p::Vector{Int}, jstart::Int)
123123
n = length(p)
124124
jshift = jstart-1
125125
scale!(p, -1)
@@ -151,7 +151,7 @@ function rowperm!(fwd::Bool, x::StridedVecOrMat, p::Vector{Int}, jstart::Int)
151151
x
152152
end
153153

154-
function rowperm!(fwd::Bool, y::StridedVector, x::StridedVector, p::Vector{Int}, jstart::Int)
154+
function rowperm!(fwd::Bool, y::AbstractVector, x::AbstractVector, p::Vector{Int}, jstart::Int)
155155
n = length(p)
156156
jshift = jstart-1
157157
@inbounds if (fwd)
@@ -167,13 +167,13 @@ function rowperm!(fwd::Bool, y::StridedVector, x::StridedVector, p::Vector{Int},
167167
end
168168

169169
## ColumnPermutation
170-
A_mul_B!(A::ColPerm, B::StridedVecOrMat, jstart::Int) = rowperm!(false, B, A.p, jstart)
171-
At_mul_B!(A::ColPerm, B::StridedVecOrMat, jstart::Int) = rowperm!(true, B, A.p, jstart)
172-
Ac_mul_B!(A::ColPerm, B::StridedVecOrMat, jstart::Int) = At_mul_B!(A, B, jstart)
170+
"""
    A_mul_B!(A::ColPerm, B::AbstractVecOrMat, jstart::Int)

Forward to `rowperm!` with `fwd = false`, handing it `B`, the permutation
vector `A.p`, and `jstart` unchanged. Whatever `rowperm!` returns is returned.
"""
function A_mul_B!(A::ColPerm, B::AbstractVecOrMat, jstart::Int)
    perm = A.p
    return rowperm!(false, B, perm, jstart)
end
171+
"""
    At_mul_B!(A::ColPerm, B::AbstractVecOrMat, jstart::Int)

Forward to `rowperm!` with `fwd = true`, handing it `B`, the permutation
vector `A.p`, and `jstart` unchanged. Whatever `rowperm!` returns is returned.
"""
function At_mul_B!(A::ColPerm, B::AbstractVecOrMat, jstart::Int)
    perm = A.p
    return rowperm!(true, B, perm, jstart)
end
172+
"""
    Ac_mul_B!(A::ColPerm, B::AbstractVecOrMat, jstart::Int)

Conjugate-transpose variant for a `ColPerm`; it simply calls
`At_mul_B!(A, B, jstart)` and returns that call's result.
"""
function Ac_mul_B!(A::ColPerm, B::AbstractVecOrMat, jstart::Int)
    return At_mul_B!(A, B, jstart)
end
173173

174-
A_mul_B!(y::StridedVector, A::ColPerm, x::StridedVector, jstart::Int) = rowperm!(false, y, x, A.p, jstart)
175-
At_mul_B!(y::StridedVector, A::ColPerm, x::StridedVector, jstart::Int) = rowperm!(true, y, x, A.p, jstart)
176-
Ac_mul_B!(y::StridedVector, A::ColPerm, x::StridedVector, jstart::Int) = At_mul_B!(y, x, A, jstart)
174+
"""
    A_mul_B!(y::AbstractVector, A::ColPerm, x::AbstractVector, jstart::Int)

Two-vector form: forward to `rowperm!` with `fwd = false`, handing it `y`,
`x`, the permutation vector `A.p`, and `jstart` unchanged. Whatever
`rowperm!` returns is returned.
"""
function A_mul_B!(y::AbstractVector, A::ColPerm, x::AbstractVector, jstart::Int)
    perm = A.p
    return rowperm!(false, y, x, perm, jstart)
end
175+
"""
    At_mul_B!(y::AbstractVector, A::ColPerm, x::AbstractVector, jstart::Int)

Two-vector form: forward to `rowperm!` with `fwd = true`, handing it `y`,
`x`, the permutation vector `A.p`, and `jstart` unchanged. Whatever
`rowperm!` returns is returned.
"""
function At_mul_B!(y::AbstractVector, A::ColPerm, x::AbstractVector, jstart::Int)
    perm = A.p
    return rowperm!(true, y, x, perm, jstart)
end
176+
# Conjugate-transpose variant for a `ColPerm`: delegates to the two-vector
# `At_mul_B!`. Arguments are forwarded in the (y, A, x, jstart) order that
# method declares; the previous (y, x, A, jstart) order put the `ColPerm` in
# the `x::AbstractVector` slot and could not dispatch to it.
Ac_mul_B!(y::AbstractVector, A::ColPerm, x::AbstractVector, jstart::Int) = At_mul_B!(y, A, x, jstart)
177177

178178
# Fast A_mul_B!, At_mul_B!, and Ac_mul_B! for an ID. These overwrite the output.
179179

@@ -339,7 +339,7 @@ for f! in (:At_mul_B!,:Ac_mul_B!)
339339
end
340340
end
341341

342-
function addtemp3totemp2!(temp2::Vector, temp3::Vector, i1::Int, i2::Int)
342+
function addtemp3totemp2!(temp2::AbstractVector, temp3::AbstractVector, i1::Int, i2::Int)
343343
z = zero(eltype(temp3))
344344
@inbounds @simd for i = i1:i2
345345
temp2[i] += temp3[i]

src/SphericalHarmonics/thinplan.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function Base.A_mul_B!(Y::Matrix, TP::ThinSphericalHarmonicPlan, X::Matrix)
4545
copy!(B, X)
4646
M, N = size(X)
4747

48-
for J = 3:2:N÷2
48+
@stepthreads for J = 3:2:N÷2
4949
if checklayer(J-1)
5050
A_mul_B_col_J!(Y, BF[J-1], B, 2J)
5151
2J < N && A_mul_B_col_J!(Y, BF[J-1], B, 2J+1)
@@ -62,7 +62,7 @@ function Base.A_mul_B!(Y::Matrix, TP::ThinSphericalHarmonicPlan, X::Matrix)
6262
end
6363
end
6464

65-
for J = 2:2:N÷2
65+
@stepthreads for J = 2:2:N÷2
6666
if checklayer(J)
6767
A_mul_B_col_J!(Y, BF[J-1], B, 2J)
6868
2J < N && A_mul_B_col_J!(Y, BF[J-1], B, 2J+1)
@@ -84,11 +84,11 @@ function Base.A_mul_B!(Y::Matrix, TP::ThinSphericalHarmonicPlan, X::Matrix)
8484
fill!(Y, zero(eltype(Y)))
8585

8686
A_mul_B_col_J!!(Y, p1, B, 1)
87-
for J = 2:4:N
87+
@stepthreads for J = 2:4:N
8888
A_mul_B_col_J!!(Y, p2, B, J)
8989
J < N && A_mul_B_col_J!!(Y, p2, B, J+1)
9090
end
91-
for J = 4:4:N
91+
@stepthreads for J = 4:4:N
9292
A_mul_B_col_J!!(Y, p1, B, J)
9393
J < N && A_mul_B_col_J!!(Y, p1, B, J+1)
9494
end
@@ -101,11 +101,11 @@ function Base.At_mul_B!(Y::Matrix, TP::ThinSphericalHarmonicPlan, X::Matrix)
101101
copy!(B, X)
102102
M, N = size(X)
103103
A_mul_B_col_J!!(Y, p1inv, B, 1)
104-
for J = 2:4:N
104+
@stepthreads for J = 2:4:N
105105
A_mul_B_col_J!!(Y, p2inv, B, J)
106106
J < N && A_mul_B_col_J!!(Y, p2inv, B, J+1)
107107
end
108-
for J = 4:4:N
108+
@stepthreads for J = 4:4:N
109109
A_mul_B_col_J!!(Y, p1inv, B, J)
110110
J < N && A_mul_B_col_J!!(Y, p1inv, B, J+1)
111111
end
@@ -114,7 +114,7 @@ function Base.At_mul_B!(Y::Matrix, TP::ThinSphericalHarmonicPlan, X::Matrix)
114114
fill!(Y, zero(eltype(Y)))
115115
copy!(Y, 1, B, 1, 3M)
116116

117-
for J = 3:2:N÷2
117+
@stepthreads for J = 3:2:N÷2
118118
if checklayer(J-1)
119119
At_mul_B_col_J!(Y, BF[J-1], B, 2J)
120120
2J < N && At_mul_B_col_J!(Y, BF[J-1], B, 2J+1)
@@ -131,7 +131,7 @@ function Base.At_mul_B!(Y::Matrix, TP::ThinSphericalHarmonicPlan, X::Matrix)
131131
end
132132
end
133133

134-
for J = 2:2:N÷2
134+
@stepthreads for J = 2:2:N÷2
135135
if checklayer(J)
136136
At_mul_B_col_J!(Y, BF[J-1], B, 2J)
137137
2J < N && At_mul_B_col_J!(Y, BF[J-1], B, 2J+1)

test/basictests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858
x = rand(T, n)
5959
y = zeros(T, k)
6060

61-
@test FastTransforms.A_mul_B!(y, A, P, x, 1, 1) == A*x
61+
@test norm(FastTransforms.A_mul_B!(y, A, P, x, 1, 1) - A*x) < 10eps()*norm(A*x)
6262

6363
x = rand(T, k)
6464
y = zeros(T, n)

0 commit comments

Comments
 (0)