Skip to content

Commit f6a3140

Browse files
daanhb and MikaelSlevinsky
authored and committed
Make sure FFTW is maximally used (#66)
1 parent 769510b commit f6a3140

File tree

2 files changed

+143
-27
lines changed

2 files changed

+143
-27
lines changed

src/fftBigFloat.jl

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
const AbstractFloats = Union{AbstractFloat,Complex{T} where T<:AbstractFloat}
22

3+
# We use these type definitions for clarity
4+
const RealFloats = T where T<:AbstractFloat
5+
const ComplexFloats = Complex{T} where T<:AbstractFloat
6+
37
if VERSION < v"0.7-"
48
import Base.FFTW: fft, fft!, rfft, irfft, ifft, ifft!, conv, dct, idct, dct!, idct!,
5-
plan_fft!, plan_ifft!, plan_dct!, plan_idct!,
6-
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct
9+
plan_fft!, plan_ifft!, plan_dct!, plan_idct!, plan_bfft, plan_bfft!,
10+
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct,
11+
plan_brfft
712
else
813
import FFTW: dct, dct!, idct, idct!,
914
plan_fft!, plan_ifft!, plan_dct!, plan_idct!,
10-
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct
15+
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct,
16+
plan_bfft, plan_bfft!, plan_brfft
1117
import DSP: conv
1218
end
1319

1420
# The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html
1521
# To add more types, add them in the union of the function's signature.
1622

1723
function generic_fft(x::Vector{T}) where T<:AbstractFloats
24+
T <: FFTW.fftwNumber && (@warn("Using generic fft for FFTW number type."))
1825
n = length(x)
1926
ispow2(n) && return generic_fft_pow2(x)
2027
ks = range(zero(real(T)),stop=n-one(real(T)),length=n)
@@ -30,18 +37,24 @@ function generic_fft!(x::Vector{T}) where T<:AbstractFloats
3037
end
3138

3239
# add rfft for AbstractFloat, by calling fft
33-
# this creates ToeplitzMatrices.rfft, so avoids changing rfft
34-
3540
# Real-input FFT built on `generic_fft`: compute the full transform and keep
# only its first div(length(v), 2) + 1 entries.
function generic_rfft(v::Vector{T}) where T<:AbstractFloats
    nkeep = div(length(v), 2) + 1
    spectrum = generic_fft(v)
    return spectrum[1:nkeep]
end
3641

37-
function generic_irfft(v::Vector{T},n::Integer) where T<:AbstractFloats
42+
# Inverse real FFT: rebuild the full length-`n` spectrum from the entries in
# `v` by appending the reversed conjugates of the mirrored entries, then
# return the real part of `generic_ifft` applied to that spectrum.
#
# `n` may be odd (`n == 2length(v)-1`) or even (`n == 2length(v)-2`); both
# satisfy `div(n,2)+1 == length(v)`. For even `n` the last entry of `v` has no
# mirrored partner and is therefore not duplicated.
function generic_irfft(v::Vector{T}, n::Integer) where T<:ComplexFloats
    m = length(v)
    @assert n == 2m-1 || (m >= 2 && n == 2m-2)
    # Entries 2:(n-m+1) are the ones that need a conjugate mirror image:
    # that is 2:m for odd n and 2:m-1 for even n.
    r = [v; conj(v[n-m+1:-1:2])]
    return real(generic_ifft(r))
end
4449

50+
# Backward transform without any scaling, expressed through the forward
# transform as conj(generic_fft(conj(x))).
function generic_bfft(x::Vector{T}) where {T <: AbstractFloats}
    forward = generic_fft(conj(x))
    return conj!(forward)
end
51+
# In-place counterpart of `generic_bfft`: overwrite the contents of `x` with
# its backward transform and hand `x` back to the caller.
function generic_bfft!(x::Vector{T}) where {T <: AbstractFloats}
    x .= generic_bfft(x)
    return x
end
55+
56+
# Backward real transform: the result of `generic_irfft(v, n)` multiplied by `n`.
function generic_brfft(v::Vector, n::Integer)
    normalized = generic_irfft(v, n)
    return normalized * n
end
57+
4558
# Inverse transform: conj(generic_fft(conj(x))) divided by the length of `x`.
function generic_ifft(x::Vector{T}) where {T<:AbstractFloats}
    unscaled = conj!(generic_fft(conj(x)))
    return unscaled / length(x)
end
4659
function generic_ifft!(x::Vector{T}) where T<:AbstractFloats
4760
x[:] = generic_ifft(x)
@@ -112,6 +125,7 @@ function generic_ifft_pow2(x::Vector{Complex{T}}) where T<:AbstractFloat
112125
end
113126

114127
function generic_dct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
128+
T <: FFTW.fftwNumber && (@warn("Using generic dct for FFTW number type."))
115129
N = length(a)
116130
twoN = convert(T,2) * N
117131
c = generic_fft([a; flipdim(a,1)])
@@ -124,6 +138,7 @@ end
124138
# Real-input DCT: promote `a` to complex, run the complex `generic_dct`
# method, and keep the real part of the result.
function generic_dct(a::AbstractArray{T}) where {T <: AbstractFloat}
    transformed = generic_dct(complex(a))
    return real(transformed)
end
125139

126140
function generic_idct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
141+
T <: FFTW.fftwNumber && (@warn("Using generic idct for FFTW number type."))
127142
N = length(a)
128143
twoN = convert(T,2)*N
129144
b = a * sqrt(twoN)
@@ -150,35 +165,39 @@ for f in (:dct, :dct!, :idct, :idct!)
150165
end
151166

152167
# dummy plans
153-
struct DummyFFTPlan{T,inplace} <: Plan{T} end
154-
struct DummyiFFTPlan{T,inplace} <: Plan{T} end
155-
struct DummyDCTPlan{T,inplace} <: Plan{T} end
156-
struct DummyiDCTPlan{T,inplace} <: Plan{T} end
157-
struct DummyrFFTPlan{T,inplace} <: Plan{T}
168+
abstract type DummyPlan{T} <: Plan{T} end
169+
struct DummyFFTPlan{T,inplace} <: DummyPlan{T} end
170+
struct DummyiFFTPlan{T,inplace} <: DummyPlan{T} end
171+
struct DummybFFTPlan{T,inplace} <: DummyPlan{T} end
172+
struct DummyDCTPlan{T,inplace} <: DummyPlan{T} end
173+
struct DummyiDCTPlan{T,inplace} <: DummyPlan{T} end
174+
struct DummyrFFTPlan{T,inplace} <: DummyPlan{T}
158175
n :: Integer
159176
end
160-
struct DummyirFFTPlan{T,inplace} <: Plan{T}
177+
struct DummyirFFTPlan{T,inplace} <: DummyPlan{T}
178+
n :: Integer
179+
end
180+
struct DummybrFFTPlan{T,inplace} <: DummyPlan{T}
161181
n :: Integer
162182
end
163183

164184
for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
165-
# (:DummyrFFTPlan,:DummyirFFTPlan),
166185
(:DummyDCTPlan,:DummyiDCTPlan))
167186
@eval begin
168187
Base.inv(::$Plan{T,inplace}) where {T,inplace} = $iPlan{T,inplace}()
169188
Base.inv(::$iPlan{T,inplace}) where {T,inplace} = $Plan{T,inplace}()
170189
end
171190
end
172191

173-
# Specific for rfft and irfft:
192+
# Specific for rfft, irfft and brfft:
174193
# The inverse of an irfft plan is the rfft plan with the same stored length `n`
# and the same `inplace` flag. The plan argument is named so that `p.n` resolves.
Base.inv(p::DummyirFFTPlan{T,inplace}) where {T,inplace} = DummyrFFTPlan{T,inplace}(p.n)
175194
# The inverse of an rfft plan is the irfft plan with the same stored length `n`
# and the same `inplace` flag. The plan argument is named so that `p.n` resolves.
Base.inv(p::DummyrFFTPlan{T,inplace}) where {T,inplace} = DummyirFFTPlan{T,inplace}(p.n)
176195

177196

178197
for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
198+
(:DummybFFTPlan,:generic_bfft,:generic_bfft!),
179199
(:DummyiFFTPlan,:generic_ifft,:generic_ifft!),
180200
(:DummyrFFTPlan,:generic_rfft,:generic_rfft!),
181-
# (:DummyirFFTPlan,:generic_irfft,:generic_irfft!),
182201
(:DummyDCTPlan,:generic_dct,:generic_dct!),
183202
(:DummyiDCTPlan,:generic_idct,:generic_idct!))
184203
@eval begin
@@ -191,13 +210,20 @@ for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
191210
end
192211
end
193212

194-
# Specific for irfft:
213+
# Specific for irfft and brfft:
195214
# Applying an in-place irfft plan forwards to `generic_irfft!` with the length
# `p.n` stored in the plan.
# NOTE(review): `generic_irfft!` is not defined in the visible part of this
# file — confirm it exists, otherwise calling this method fails.
*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n)
196215
# Out-of-place application of an irfft plan: delegate to `generic_irfft`,
# passing the length stored in the plan.
function *(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N}
    return generic_irfft(x, p.n)
end
197216
# Three-argument application of an irfft plan: compute `generic_irfft(x, p.n)`,
# store it into `C`, and return `C`.
function LAmul!(C::StridedVector, p::DummyirFFTPlan, x::StridedVector)
    result = generic_irfft(x, p.n)
    C[:] = result
    return C
end
220+
# Applying an in-place brfft plan forwards to `generic_brfft!` with the length
# `p.n` stored in the plan.
# NOTE(review): `generic_brfft!` is not defined in the visible part of this
# file — confirm it exists, otherwise calling this method fails.
*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n)
221+
# Out-of-place application of a brfft plan: delegate to `generic_brfft`,
# passing the length stored in the plan.
function *(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N}
    return generic_brfft(x, p.n)
end
222+
# Three-argument application of a brfft plan: compute `generic_brfft(x, p.n)`,
# store it into `C`, and return `C`.
function LAmul!(C::StridedVector, p::DummybrFFTPlan, x::StridedVector)
    result = generic_brfft(x, p.n)
    C[:] = result
    return C
end
226+
201227

202228
# We override these for AbstractFloat, so that conversion from reals to
203229
# complex numbers works for any AbstractFloat (instead of only BlasFloat's)
@@ -207,21 +233,38 @@ AbstractFFTs.realfloat(x::StridedArray{<:Real}) = x
207233
# unsupported (as defined in AbstractFFTs)
208234
# For any AbstractFloat type, return that type unchanged.
function AbstractFFTs._fftfloat(::Type{T}) where {T <: AbstractFloat}
    return T
end
209235

210-
plan_fft!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyFFTPlan{Complex{real(T)},true}()
211-
plan_ifft!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiFFTPlan{Complex{real(T)},true}()
212236

213-
# plan_rfft!(x::StridedArray{T}) where {T <: AbstractFloat} = DummyrFFTPlan{Complex{real(T)},true}()
214-
# plan_irfft!(x::StridedArray{T},n::Integer) where {T <: AbstractFloat} = DummyirFFTPlan{Complex{real(T)},true}()
215-
plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true}()
216-
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true}()
237+
# We intercept the calls to plan_X(x, region) below.
238+
# In order not to capture any calls that should go to FFTW, we have to be
239+
# careful about the typing, so that the calls to FFTW remain more specific.
240+
# This is the reason for using StridedArray below. We also have to carefully
241+
# distinguish between real and complex arguments.
242+
243+
# Forward-FFT plans for complex floating-point element types. Only the element
# type of `x` is used; neither `x` itself nor `region` is read by the body.
function plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats}
    return DummyFFTPlan{Complex{real(T)},false}()
end
function plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats}
    return DummyFFTPlan{Complex{real(T)},true}()
end
245+
246+
# Backward-FFT plans for complex floating-point element types. Only the element
# type of `x` is used; neither `x` itself nor `region` is read by the body.
function plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats}
    return DummybFFTPlan{Complex{real(T)},false}()
end
function plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats}
    return DummybFFTPlan{Complex{real(T)},true}()
end
248+
249+
# The ifft plans are automatically provided in terms of the bfft plans above.
250+
# plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false}()
251+
# plan_ifft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},true}()
217252

218-
plan_fft(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyFFTPlan{Complex{real(T)},false}()
219-
plan_ifft(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiFFTPlan{Complex{real(T)},false}()
220-
plan_rfft(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyrFFTPlan{Complex{real(T)},false}(length(x))
221-
plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: AbstractFloats} = DummyirFFTPlan{Complex{real(T)},false}(n)
222253
# DCT plans for real or complex floating-point element types. The plan is
# parameterized directly by the element type `T`; `x` and `region` are not
# read by the body.
function plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats}
    return DummyDCTPlan{T,false}()
end
function plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats}
    return DummyDCTPlan{T,true}()
end
255+
223256
# Inverse-DCT plans for real or complex floating-point element types. The plan
# is parameterized directly by the element type `T`; `x` and `region` are not
# read by the body.
function plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats}
    return DummyiDCTPlan{T,false}()
end
function plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats}
    return DummyiDCTPlan{T,true}()
end
258+
259+
# Real-input FFT plan: records `length(x)` in the plan. `region` is not read.
function plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats}
    return DummyrFFTPlan{Complex{real(T)},false}(length(x))
end
# Backward real FFT plan: records the requested output length `n` in the plan.
# Neither `x` (beyond its element type) nor `region` is read.
function plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats}
    return DummybrFFTPlan{Complex{real(T)},false}(n)
end
261+
262+
# A plan for irfft is created in terms of a plan for brfft.
263+
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false}(n)
224264

265+
# These don't exist for now:
266+
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()
267+
# plan_irfft!(x::StridedArray{T},n::Integer) where {T <: RealFloats} = DummyirFFTPlan{Complex{real(T)},true}()
225268

226269
function interlace(a::Vector{S},b::Vector{V}) where {S<:Number,V<:Number}
227270
na=length(a);nb=length(b)

test/fftBigFloattests.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,77 @@ end
2929
@test norm(idct(c)-idct(map(ComplexF64,c)),Inf) < 10eps()
3030
@test norm(idct(dct(c))-c,Inf) < 1000eps(BigFloat)
3131
@test norm(dct(idct(c))-c,Inf) < 1000eps(BigFloat)
32+
33+
# Make sure we don't accidentally hijack any FFTW plans
34+
for T in (Float32, Float64)
35+
@test plan_fft(rand(BigFloat,10)) isa FastTransforms.DummyPlan
36+
@test plan_fft(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
37+
@test plan_fft(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
38+
@test plan_fft(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
39+
@test plan_fft!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
40+
@test plan_fft!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
41+
@test !( plan_fft(rand(T,10)) isa FastTransforms.DummyPlan )
42+
@test !( plan_fft(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
43+
@test !( plan_fft(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
44+
@test !( plan_fft(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
45+
@test !( plan_fft!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
46+
@test !( plan_fft!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
47+
48+
@test plan_ifft(rand(T,10)) isa FFTW.ScaledPlan
49+
@test plan_ifft(rand(T,10), 1:1) isa FFTW.ScaledPlan
50+
@test plan_ifft(rand(Complex{T},10)) isa FFTW.ScaledPlan
51+
@test plan_ifft(rand(Complex{T},10), 1:1) isa FFTW.ScaledPlan
52+
@test plan_ifft!(rand(Complex{T},10)) isa FFTW.ScaledPlan
53+
@test plan_ifft!(rand(Complex{T},10), 1:1) isa FFTW.ScaledPlan
54+
55+
@test plan_bfft(rand(BigFloat,10)) isa FastTransforms.DummyPlan
56+
@test plan_bfft(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
57+
@test plan_bfft(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
58+
@test plan_bfft(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
59+
@test plan_bfft!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
60+
@test plan_bfft!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
61+
@test !( plan_bfft(rand(T,10)) isa FastTransforms.DummyPlan )
62+
@test !( plan_bfft(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
63+
@test !( plan_bfft(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
64+
@test !( plan_bfft(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
65+
@test !( plan_bfft!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
66+
@test !( plan_bfft!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
67+
68+
@test plan_dct(rand(BigFloat,10)) isa FastTransforms.DummyPlan
69+
@test plan_dct(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
70+
@test plan_dct(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
71+
@test plan_dct(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
72+
@test plan_dct!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
73+
@test plan_dct!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
74+
@test !( plan_dct(rand(T,10)) isa FastTransforms.DummyPlan )
75+
@test !( plan_dct(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
76+
@test !( plan_dct(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
77+
@test !( plan_dct(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
78+
@test !( plan_dct!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
79+
@test !( plan_dct!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
80+
81+
@test plan_idct(rand(BigFloat,10)) isa FastTransforms.DummyPlan
82+
@test plan_idct(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
83+
@test plan_idct(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
84+
@test plan_idct(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
85+
@test plan_idct!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
86+
@test plan_idct!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
87+
@test !( plan_idct(rand(T,10)) isa FastTransforms.DummyPlan )
88+
@test !( plan_idct(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
89+
@test !( plan_idct(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
90+
@test !( plan_idct(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
91+
@test !( plan_idct!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
92+
@test !( plan_idct!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
93+
94+
@test plan_rfft(rand(BigFloat,10)) isa FastTransforms.DummyPlan
95+
@test plan_rfft(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
96+
@test plan_brfft(rand(Complex{BigFloat},10), 19) isa FastTransforms.DummyPlan
97+
@test plan_brfft(rand(Complex{BigFloat},10), 19, 1:1) isa FastTransforms.DummyPlan
98+
@test !( plan_rfft(rand(T,10)) isa FastTransforms.DummyPlan )
99+
@test !( plan_rfft(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
100+
@test !( plan_brfft(rand(Complex{T},10), 19) isa FastTransforms.DummyPlan )
101+
@test !( plan_brfft(rand(Complex{T},10), 19, 1:1) isa FastTransforms.DummyPlan )
102+
103+
end
104+
32105
end

0 commit comments

Comments
 (0)