Skip to content

Make sure FFTW is maximally used #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 70 additions & 27 deletions src/fftBigFloat.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
const AbstractFloats = Union{AbstractFloat,Complex{T} where T<:AbstractFloat}

# We use these type definitions for clarity
const RealFloats = T where T<:AbstractFloat
const ComplexFloats = Complex{T} where T<:AbstractFloat

if VERSION < v"0.7-"
import Base.FFTW: fft, fft!, rfft, irfft, ifft, ifft!, conv, dct, idct, dct!, idct!,
plan_fft!, plan_ifft!, plan_dct!, plan_idct!,
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct
plan_fft!, plan_ifft!, plan_dct!, plan_idct!, plan_bfft, plan_bfft!,
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct,
plan_brfft
else
import FFTW: dct, dct!, idct, idct!,
plan_fft!, plan_ifft!, plan_dct!, plan_idct!,
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct,
plan_bfft, plan_bfft!, plan_brfft
import DSP: conv
end

# The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html
# To add more types, add them in the union of the function's signature.

function generic_fft(x::Vector{T}) where T<:AbstractFloats
T <: FFTW.fftwNumber && (@warn("Using generic fft for FFTW number type."))
n = length(x)
ispow2(n) && return generic_fft_pow2(x)
ks = range(zero(real(T)),stop=n-one(real(T)),length=n)
Expand All @@ -30,18 +37,24 @@ function generic_fft!(x::Vector{T}) where T<:AbstractFloats
end

# add rfft for AbstractFloat, by calling fft
# this creates ToeplitzMatrices.rfft, so avoids changing rfft

generic_rfft(v::Vector{T}) where T<:AbstractFloats = generic_fft(v)[1:div(length(v),2)+1]

function generic_irfft(v::Vector{T},n::Integer) where T<:AbstractFloats
function generic_irfft(v::Vector{T}, n::Integer) where T<:ComplexFloats
@assert n==2length(v)-1
r = Vector{Complex{real(T)}}(undef, n)
r = Vector{T}(undef, n)
r[1:length(v)]=v
r[length(v)+1:end]=reverse(conj(v[2:end]))
real(generic_ifft(r))
end

generic_bfft(x::Vector{T}) where {T <: AbstractFloats} = conj!(generic_fft(conj(x)))
function generic_bfft!(x::Vector{T}) where {T <: AbstractFloats}
x[:] = generic_bfft(x)
return x
end

generic_brfft(v::Vector, n::Integer) = generic_irfft(v, n)*n

generic_ifft(x::Vector{T}) where {T<:AbstractFloats} = conj!(generic_fft(conj(x)))/length(x)
function generic_ifft!(x::Vector{T}) where T<:AbstractFloats
x[:] = generic_ifft(x)
Expand Down Expand Up @@ -112,6 +125,7 @@ function generic_ifft_pow2(x::Vector{Complex{T}}) where T<:AbstractFloat
end

function generic_dct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
T <: FFTW.fftwNumber && (@warn("Using generic dct for FFTW number type."))
N = length(a)
twoN = convert(T,2) * N
c = generic_fft([a; flipdim(a,1)])
Expand All @@ -124,6 +138,7 @@ end
generic_dct(a::AbstractArray{T}) where {T <: AbstractFloat} = real(generic_dct(complex(a)))

function generic_idct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
T <: FFTW.fftwNumber && (@warn("Using generic idct for FFTW number type."))
N = length(a)
twoN = convert(T,2)*N
b = a * sqrt(twoN)
Expand All @@ -150,35 +165,39 @@ for f in (:dct, :dct!, :idct, :idct!)
end

# dummy plans
struct DummyFFTPlan{T,inplace} <: Plan{T} end
struct DummyiFFTPlan{T,inplace} <: Plan{T} end
struct DummyDCTPlan{T,inplace} <: Plan{T} end
struct DummyiDCTPlan{T,inplace} <: Plan{T} end
struct DummyrFFTPlan{T,inplace} <: Plan{T}
abstract type DummyPlan{T} <: Plan{T} end
struct DummyFFTPlan{T,inplace} <: DummyPlan{T} end
struct DummyiFFTPlan{T,inplace} <: DummyPlan{T} end
struct DummybFFTPlan{T,inplace} <: DummyPlan{T} end
struct DummyDCTPlan{T,inplace} <: DummyPlan{T} end
struct DummyiDCTPlan{T,inplace} <: DummyPlan{T} end
struct DummyrFFTPlan{T,inplace} <: DummyPlan{T}
n :: Integer
end
struct DummyirFFTPlan{T,inplace} <: Plan{T}
struct DummyirFFTPlan{T,inplace} <: DummyPlan{T}
n :: Integer
end
struct DummybrFFTPlan{T,inplace} <: DummyPlan{T}
n :: Integer
end

for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
# (:DummyrFFTPlan,:DummyirFFTPlan),
(:DummyDCTPlan,:DummyiDCTPlan))
@eval begin
Base.inv(::$Plan{T,inplace}) where {T,inplace} = $iPlan{T,inplace}()
Base.inv(::$iPlan{T,inplace}) where {T,inplace} = $Plan{T,inplace}()
end
end

# Specific for rfft and irfft:
# Specific for rfft, irfft and brfft:
Base.inv(::DummyirFFTPlan{T,inplace}) where {T,inplace} = DummyrFFTPlan{T,Inplace}(p.n)
Base.inv(::DummyrFFTPlan{T,inplace}) where {T,inplace} = DummyirFFTPlan{T,Inplace}(p.n)


for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
(:DummybFFTPlan,:generic_bfft,:generic_bfft!),
(:DummyiFFTPlan,:generic_ifft,:generic_ifft!),
(:DummyrFFTPlan,:generic_rfft,:generic_rfft!),
# (:DummyirFFTPlan,:generic_irfft,:generic_irfft!),
(:DummyDCTPlan,:generic_dct,:generic_dct!),
(:DummyiDCTPlan,:generic_idct,:generic_idct!))
@eval begin
Expand All @@ -191,13 +210,20 @@ for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
end
end

# Specific for irfft:
# Specific for irfft and brfft:
*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n)
*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.n)
function LAmul!(C::StridedVector, p::DummyirFFTPlan, x::StridedVector)
C[:] = generic_irfft(x, p.n)
C
end
*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n)
*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.n)
function LAmul!(C::StridedVector, p::DummybrFFTPlan, x::StridedVector)
C[:] = generic_brfft(x, p.n)
C
end


# We override these for AbstractFloat, so that conversion from reals to
# complex numbers works for any AbstractFloat (instead of only BlasFloat's)
Expand All @@ -207,21 +233,38 @@ AbstractFFTs.realfloat(x::StridedArray{<:Real}) = x
# unsupported (as defined in AbstractFFTs)
AbstractFFTs._fftfloat(::Type{T}) where {T <: AbstractFloat} = T

plan_fft!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyFFTPlan{Complex{real(T)},true}()
plan_ifft!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiFFTPlan{Complex{real(T)},true}()

# plan_rfft!(x::StridedArray{T}) where {T <: AbstractFloat} = DummyrFFTPlan{Complex{real(T)},true}()
# plan_irfft!(x::StridedArray{T},n::Integer) where {T <: AbstractFloat} = DummyirFFTPlan{Complex{real(T)},true}()
plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true}()
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true}()
# We intercept the calls to plan_X(x, region) below.
# In order not to capture any calls that should go to FFTW, we have to be
# careful about the typing, so that the calls to FFTW remain more specific.
# This is the reason for using StridedArray below. We also have to carefully
# distinguish between real and complex arguments.

plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false}()
plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true}()

plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false}()
plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true}()

# The ifft plans are automatically provided in terms of the bfft plans above.
# plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false}()
# plan_ifft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},true}()

plan_fft(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyFFTPlan{Complex{real(T)},false}()
plan_ifft(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiFFTPlan{Complex{real(T)},false}()
plan_rfft(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyrFFTPlan{Complex{real(T)},false}(length(x))
plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: AbstractFloats} = DummyirFFTPlan{Complex{real(T)},false}(n)
plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,false}()
plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true}()

plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false}()
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true}()

plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false}(length(x))
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{Complex{real(T)},false}(n)

# A plan for irfft is created in terms of a plan for brfft.
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false}(n)

# These don't exist for now:
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()
# plan_irfft!(x::StridedArray{T},n::Integer) where {T <: RealFloats} = DummyirFFTPlan{Complex{real(T)},true}()

function interlace(a::Vector{S},b::Vector{V}) where {S<:Number,V<:Number}
na=length(a);nb=length(b)
Expand Down
73 changes: 73 additions & 0 deletions test/fftBigFloattests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,77 @@ end
@test norm(idct(c)-idct(map(ComplexF64,c)),Inf) < 10eps()
@test norm(idct(dct(c))-c,Inf) < 1000eps(BigFloat)
@test norm(dct(idct(c))-c,Inf) < 1000eps(BigFloat)

# Make sure we don't accidentally hijack any FFTW plans
for T in (Float32, Float64)
@test plan_fft(rand(BigFloat,10)) isa FastTransforms.DummyPlan
@test plan_fft(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
@test plan_fft(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_fft(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test plan_fft!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_fft!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test !( plan_fft(rand(T,10)) isa FastTransforms.DummyPlan )
@test !( plan_fft(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_fft(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_fft(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_fft!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_fft!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )

@test plan_ifft(rand(T,10)) isa FFTW.ScaledPlan
@test plan_ifft(rand(T,10), 1:1) isa FFTW.ScaledPlan
@test plan_ifft(rand(Complex{T},10)) isa FFTW.ScaledPlan
@test plan_ifft(rand(Complex{T},10), 1:1) isa FFTW.ScaledPlan
@test plan_ifft!(rand(Complex{T},10)) isa FFTW.ScaledPlan
@test plan_ifft!(rand(Complex{T},10), 1:1) isa FFTW.ScaledPlan

@test plan_bfft(rand(BigFloat,10)) isa FastTransforms.DummyPlan
@test plan_bfft(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
@test plan_bfft(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_bfft(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test plan_bfft!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_bfft!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test !( plan_bfft(rand(T,10)) isa FastTransforms.DummyPlan )
@test !( plan_bfft(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_bfft(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_bfft(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_bfft!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_bfft!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )

@test plan_dct(rand(BigFloat,10)) isa FastTransforms.DummyPlan
@test plan_dct(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
@test plan_dct(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_dct(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test plan_dct!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_dct!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test !( plan_dct(rand(T,10)) isa FastTransforms.DummyPlan )
@test !( plan_dct(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_dct(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_dct(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_dct!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_dct!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )

@test plan_idct(rand(BigFloat,10)) isa FastTransforms.DummyPlan
@test plan_idct(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
@test plan_idct(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_idct(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test plan_idct!(rand(Complex{BigFloat},10)) isa FastTransforms.DummyPlan
@test plan_idct!(rand(Complex{BigFloat},10), 1:1) isa FastTransforms.DummyPlan
@test !( plan_idct(rand(T,10)) isa FastTransforms.DummyPlan )
@test !( plan_idct(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_idct(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_idct(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_idct!(rand(Complex{T},10)) isa FastTransforms.DummyPlan )
@test !( plan_idct!(rand(Complex{T},10), 1:1) isa FastTransforms.DummyPlan )

@test plan_rfft(rand(BigFloat,10)) isa FastTransforms.DummyPlan
@test plan_rfft(rand(BigFloat,10), 1:1) isa FastTransforms.DummyPlan
@test plan_brfft(rand(Complex{BigFloat},10), 19) isa FastTransforms.DummyPlan
@test plan_brfft(rand(Complex{BigFloat},10), 19, 1:1) isa FastTransforms.DummyPlan
@test !( plan_rfft(rand(T,10)) isa FastTransforms.DummyPlan )
@test !( plan_rfft(rand(T,10), 1:1) isa FastTransforms.DummyPlan )
@test !( plan_brfft(rand(Complex{T},10), 19) isa FastTransforms.DummyPlan )
@test !( plan_brfft(rand(Complex{T},10), 19, 1:1) isa FastTransforms.DummyPlan )

end

end