Skip to content

Commit e1d6086

Browse files
farm out BigFloat FFT to GenericFFT
preserve `conv` pirating until DSP's PR gets merged and tagged JuliaDSP/DSP.jl#477
1 parent a80a47d commit e1d6086

File tree

4 files changed

+9
-453
lines changed

4 files changed

+9
-453
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
99
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1010
FastTransforms_jll = "34b6f7d7-08f9-5794-9e10-3819e4c7e49a"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
12+
GenericFFT = "a8297547-1b15-4a5a-a998-a2ac5f1cef28"
1213
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -23,6 +24,7 @@ FFTW = "1"
2324
FastGaussQuadrature = "0.4"
2425
FastTransforms_jll = "0.6.0"
2526
FillArrays = "0.9, 0.10, 0.11, 0.12, 0.13"
27+
GenericFFT = "0.1"
2628
Reexport = "0.2, 1.0"
2729
SpecialFunctions = "0.10, 1, 2"
2830
ToeplitzMatrices = "0.6, 0.7"

src/FastTransforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import DSP
77

88
@reexport using AbstractFFTs
99
@reexport using FFTW
10+
@reexport using GenericFFT
1011

1112
import Base: convert, unsafe_convert, eltype, ndims, adjoint, transpose, show,
1213
*, \, inv, length, size, view, getindex

src/fftBigFloat.jl

Lines changed: 4 additions & 345 deletions
Original file line numberDiff line numberDiff line change
@@ -1,345 +1,4 @@
1-
const AbstractFloats = Union{AbstractFloat,Complex{T} where T<:AbstractFloat}
2-
3-
# We use these type definitions for clarity
4-
const RealFloats = T where T<:AbstractFloat
5-
const ComplexFloats = Complex{T} where T<:AbstractFloat
6-
7-
8-
# The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html
9-
# To add more types, add them in the union of the function's signature.
10-
11-
function generic_fft(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
12-
region == 1 && (ret = generic_fft(x))
13-
ret
14-
end
15-
16-
function generic_fft!(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
17-
region == 1 && (x[:] .= generic_fft(x))
18-
x
19-
end
20-
21-
function generic_fft(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
22-
region == 1:1 && (ret = generic_fft(x))
23-
ret
24-
end
25-
26-
function generic_fft!(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
27-
region == 1:1 && (x[:] .= generic_fft(x))
28-
x
29-
end
30-
31-
function generic_fft(x::StridedMatrix{T}, region::Integer) where T<:AbstractFloats
32-
if region == 1
33-
ret = hcat([generic_fft(x[:, j]) for j in 1:size(x, 2)]...)
34-
end
35-
ret
36-
end
37-
38-
function generic_fft!(x::StridedMatrix{T}, region::Integer) where T<:AbstractFloats
39-
if region == 1
40-
for j in 1:size(x, 2)
41-
x[:, j] .= generic_fft(x[:, j])
42-
end
43-
end
44-
x
45-
end
46-
47-
function generic_fft(x::Vector{T}) where T<:AbstractFloats
48-
T <: FFTW.fftwNumber && (@warn("Using generic fft for FFTW number type."))
49-
n = length(x)
50-
ispow2(n) && return generic_fft_pow2(x)
51-
ks = range(zero(real(T)),stop=n-one(real(T)),length=n)
52-
Wks = exp.((-im).*convert(T,π).*ks.^2 ./ n)
53-
xq, wq = x.*Wks, conj([exp(-im*convert(T,π)*n);reverse(Wks);Wks[2:end]])
54-
return Wks.*_conv!(xq,wq)[n+1:2n]
55-
end
56-
57-
generic_bfft(x::StridedArray{T, N}, region) where {T <: AbstractFloats, N} = conj!(generic_fft(conj(x), region))
58-
generic_bfft!(x::StridedArray{T, N}, region) where {T <: AbstractFloats, N} = conj!(generic_fft!(conj!(x), region))
59-
generic_ifft(x::StridedArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(length(x), conj!(generic_fft(conj(x), region)))
60-
generic_ifft!(x::StridedArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(length(x), conj!(generic_fft!(conj!(x), region)))
61-
62-
generic_rfft(v::Vector{T}, region) where T<:AbstractFloats = generic_fft(v, region)[1:div(length(v),2)+1]
63-
function generic_irfft(v::Vector{T}, n::Integer, region) where T<:ComplexFloats
64-
@assert n==2length(v)-1
65-
r = Vector{T}(undef, n)
66-
r[1:length(v)]=v
67-
r[length(v)+1:end]=reverse(conj(v[2:end]))
68-
real(generic_ifft(r, region))
69-
end
70-
generic_brfft(v::StridedArray, n::Integer, region) = generic_irfft(v, n, region)*n
71-
72-
function _conv!(u::StridedVector{T}, v::StridedVector{T}) where T<:AbstractFloats
73-
nu = length(u)
74-
nv = length(v)
75-
n = nu + nv - 1
76-
np2 = nextpow(2, n)
77-
append!(u, zeros(T, np2-nu))
78-
append!(v, zeros(T, np2-nv))
79-
y = generic_ifft_pow2(generic_fft_pow2(u).*generic_fft_pow2(v))
80-
#TODO This would not handle Dual/ComplexDual numbers correctly
81-
y = T<:Real ? real(y[1:n]) : y[1:n]
82-
end
83-
84-
conv(u::AbstractArray{T, N}, v::AbstractArray{T, N}) where {T<:AbstractFloat, N} = _conv!(deepcopy(u), deepcopy(v))
85-
conv(u::AbstractArray{T, N}, v::AbstractArray{Complex{T}, N}) where {T<:AbstractFloat, N} = _conv!(complex(deepcopy(u)), deepcopy(v))
86-
conv(u::AbstractArray{Complex{T}, N}, v::AbstractArray{T, N}) where {T<:AbstractFloat, N} = _conv!(deepcopy(u), complex(deepcopy(v)))
87-
conv(u::AbstractArray{Complex{T}, N}, v::AbstractArray{Complex{T}, N}) where {T<:AbstractFloat, N} = _conv!(deepcopy(u), deepcopy(v))
88-
89-
# This is a Cooley-Tukey FFT algorithm inspired by many widely available algorithms including:
90-
# c_radix2.c in the GNU Scientific Library and four1 in the Numerical Recipes in C.
91-
# However, the trigonometric recurrence is improved for greater efficiency.
92-
# The algorithm starts with bit-reversal, then divides and conquers in-place.
93-
function generic_fft_pow2!(x::Vector{T}) where T<:AbstractFloat
94-
n,big2=length(x),2one(T)
95-
nn,j=n÷2,1
96-
for i=1:2:n-1
97-
if j>i
98-
x[j], x[i] = x[i], x[j]
99-
x[j+1], x[i+1] = x[i+1], x[j+1]
100-
end
101-
m = nn
102-
while m 2 && j > m
103-
j -= m
104-
m = m÷2
105-
end
106-
j += m
107-
end
108-
logn = 2
109-
while logn < n
110-
θ=-big2/logn
111-
wtemp = sinpi/2)
112-
wpr, wpi = -2wtemp^2, sinpi(θ)
113-
wr, wi = one(T), zero(T)
114-
for m=1:2:logn-1
115-
for i=m:2logn:n
116-
j=i+logn
117-
mixr, mixi = wr*x[j]-wi*x[j+1], wr*x[j+1]+wi*x[j]
118-
x[j], x[j+1] = x[i]-mixr, x[i+1]-mixi
119-
x[i], x[i+1] = x[i]+mixr, x[i+1]+mixi
120-
end
121-
wr = (wtemp=wr)*wpr-wi*wpi+wr
122-
wi = wi*wpr+wtemp*wpi+wi
123-
end
124-
logn = logn << 1
125-
end
126-
return x
127-
end
128-
129-
function generic_fft_pow2(x::Vector{Complex{T}}) where T<:AbstractFloat
130-
y = interlace(real(x), imag(x))
131-
generic_fft_pow2!(y)
132-
return complex.(y[1:2:end], y[2:2:end])
133-
end
134-
generic_fft_pow2(x::Vector{T}) where T<:AbstractFloat = generic_fft_pow2(complex(x))
135-
136-
function generic_ifft_pow2(x::Vector{Complex{T}}) where T<:AbstractFloat
137-
y = interlace(real(x), -imag(x))
138-
generic_fft_pow2!(y)
139-
return ldiv!(length(x), conj!(complex.(y[1:2:end], y[2:2:end])))
140-
end
141-
142-
function generic_dct(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
143-
region == 1 && (ret = generic_dct(x))
144-
ret
145-
end
146-
147-
function generic_dct!(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
148-
region == 1 && (x[:] .= generic_dct(x))
149-
x
150-
end
151-
152-
function generic_idct(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
153-
region == 1 && (ret = generic_idct(x))
154-
ret
155-
end
156-
157-
function generic_idct!(x::StridedVector{T}, region::Integer) where T<:AbstractFloats
158-
region == 1 && (x[:] .= generic_idct(x))
159-
x
160-
end
161-
162-
function generic_dct(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
163-
region == 1:1 && (ret = generic_dct(x))
164-
ret
165-
end
166-
167-
function generic_dct!(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
168-
region == 1:1 && (x[:] .= generic_dct(x))
169-
x
170-
end
171-
172-
function generic_idct(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
173-
region == 1:1 && (ret = generic_idct(x))
174-
ret
175-
end
176-
177-
function generic_idct!(x::StridedVector{T}, region::UnitRange{I}) where {T<:AbstractFloats, I<:Integer}
178-
region == 1:1 && (x[:] .= generic_idct(x))
179-
x
180-
end
181-
182-
function generic_dct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
183-
T <: FFTW.fftwNumber && (@warn("Using generic dct for FFTW number type."))
184-
N = length(a)
185-
twoN = convert(T,2) * N
186-
c = generic_fft([a; reverse(a, dims=1)]) # c = generic_fft([a; flipdim(a,1)])
187-
d = c[1:N]
188-
d .*= exp.((-im*convert(T, pi)).*(0:N-1)./twoN)
189-
d[1] = d[1] / sqrt(convert(T, 2))
190-
lmul!(inv(sqrt(twoN)), d)
191-
end
192-
193-
generic_dct(a::AbstractArray{T}) where {T <: AbstractFloat} = real(generic_dct(complex(a)))
194-
195-
function generic_idct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
196-
T <: FFTW.fftwNumber && (@warn("Using generic idct for FFTW number type."))
197-
N = length(a)
198-
twoN = convert(T,2)*N
199-
b = a * sqrt(twoN)
200-
b[1] = b[1] * sqrt(convert(T,2))
201-
shift = exp.(-im * 2 * convert(T, pi) * (N - convert(T,1)/2) * (0:(2N-1)) / twoN)
202-
b = [b; 0; -reverse(b[2:end], dims=1)] .* shift # b = [b; 0; -flipdim(b[2:end],1)] .* shift
203-
c = ifft(b)
204-
reverse(c[1:N]; dims=1)#flipdim(c[1:N],1)
205-
end
206-
207-
generic_idct(a::AbstractArray{T}) where {T <: AbstractFloat} = real(generic_idct(complex(a)))
208-
209-
210-
# These lines mimick the corresponding ones in FFTW/src/dct.jl, but with
211-
# AbstractFloat rather than fftwNumber.
212-
for f in (:dct, :dct!, :idct, :idct!)
213-
pf = Symbol("plan_", f)
214-
@eval begin
215-
$f(x::AbstractArray{<:AbstractFloats}) = $pf(x) * x
216-
$f(x::AbstractArray{<:AbstractFloats}, region) = $pf(x, region) * x
217-
end
218-
end
219-
220-
# dummy plans
221-
abstract type DummyPlan{T} <: Plan{T} end
222-
for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiDCTPlan)
223-
# All plans need an initially undefined pinv field
224-
@eval begin
225-
mutable struct $P{T,inplace,G} <: DummyPlan{T}
226-
region::G # region (iterable) of dims that are transformed
227-
pinv::DummyPlan{T}
228-
$P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region)
229-
end
230-
end
231-
end
232-
for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan)
233-
@eval begin
234-
mutable struct $P{T,inplace,G} <: DummyPlan{T}
235-
n::Integer
236-
region::G # region (iterable) of dims that are transformed
237-
pinv::DummyPlan{T}
238-
$P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region)
239-
end
240-
end
241-
end
242-
243-
for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
244-
(:DummyDCTPlan,:DummyiDCTPlan))
245-
@eval begin
246-
plan_inv(p::$Plan{T,inplace,G}) where {T,inplace,G} = $iPlan{T,inplace,G}(p.region)
247-
plan_inv(p::$iPlan{T,inplace,G}) where {T,inplace,G} = $Plan{T,inplace,G}(p.region)
248-
end
249-
end
250-
251-
# Specific for rfft, irfft and brfft:
252-
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{T,Inplace,G}(p.n, p.region)
253-
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{T,Inplace,G}(p.n, p.region)
254-
255-
256-
257-
for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
258-
(:DummybFFTPlan,:generic_bfft,:generic_bfft!),
259-
(:DummyiFFTPlan,:generic_ifft,:generic_ifft!),
260-
(:DummyrFFTPlan,:generic_rfft,:generic_rfft!),
261-
(:DummyDCTPlan,:generic_dct,:generic_dct!),
262-
(:DummyiDCTPlan,:generic_idct,:generic_idct!))
263-
@eval begin
264-
*(p::$Plan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff!(x, p.region)
265-
*(p::$Plan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff(x, p.region)
266-
function mul!(C::StridedVector, p::$Plan, x::StridedVector)
267-
C[:] = $ff(x, p.region)
268-
C
269-
end
270-
end
271-
end
272-
273-
# Specific for irfft and brfft:
274-
*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n, p.region)
275-
*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.n, p.region)
276-
function mul!(C::StridedVector, p::DummyirFFTPlan, x::StridedVector)
277-
C[:] = generic_irfft(x, p.n, p.region)
278-
C
279-
end
280-
*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n, p.region)
281-
*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.n, p.region)
282-
function mul!(C::StridedVector, p::DummybrFFTPlan, x::StridedVector)
283-
C[:] = generic_brfft(x, p.n, p.region)
284-
C
285-
end
286-
287-
288-
# We override these for AbstractFloat, so that conversion from reals to
289-
# complex numbers works for any AbstractFloat (instead of only BlasFloat's)
290-
AbstractFFTs.complexfloat(x::StridedArray{Complex{<:AbstractFloat}}) = x
291-
AbstractFFTs.realfloat(x::StridedArray{<:Real}) = x
292-
# We override this one in order to avoid throwing an error that the type is
293-
# unsupported (as defined in AbstractFFTs)
294-
AbstractFFTs._fftfloat(::Type{T}) where {T <: AbstractFloat} = T
295-
296-
297-
# We intercept the calls to plan_X(x, region) below.
298-
# In order not to capture any calls that should go to FFTW, we have to be
299-
# careful about the typing, so that the calls to FFTW remain more specific.
300-
# This is the reason for using StridedArray below. We also have to carefully
301-
# distinguish between real and complex arguments.
302-
303-
plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false,typeof(region)}(region)
304-
plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true,typeof(region)}(region)
305-
306-
plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false,typeof(region)}(region)
307-
plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true,typeof(region)}(region)
308-
309-
# The ifft plans are automatically provided in terms of the bfft plans above.
310-
# plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false,typeof(region)}(region)
311-
# plan_ifft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},true,typeof(region)}(region)
312-
313-
plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,false,typeof(region)}(region)
314-
plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true,typeof(region)}(region)
315-
316-
plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region)
317-
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region)
318-
319-
plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false,typeof(region)}(length(x), region)
320-
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
321-
322-
# A plan for irfft is created in terms of a plan for brfft.
323-
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
324-
325-
# These don't exist for now:
326-
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()
327-
# plan_irfft!(x::StridedArray{T},n::Integer) where {T <: RealFloats} = DummyirFFTPlan{Complex{real(T)},true}()
328-
329-
function interlace(a::Vector{S},b::Vector{V}) where {S<:Number,V<:Number}
330-
na=length(a);nb=length(b)
331-
T=promote_type(S,V)
332-
if nbna
333-
ret=zeros(T,2nb)
334-
ret[1:2:1+2*(na-1)]=a
335-
ret[2:2:end]=b
336-
ret
337-
else
338-
ret=zeros(T,2na-1)
339-
ret[1:2:end]=a
340-
if !isempty(b)
341-
ret[2:2:2+2*(nb-1)]=b
342-
end
343-
ret
344-
end
345-
end
1+
conv(u::AbstractArray{T, N}, v::AbstractArray{T, N}) where {T<:AbstractFloat, N} = GenericFFT._conv!(deepcopy(u), deepcopy(v))
2+
conv(u::AbstractArray{T, N}, v::AbstractArray{Complex{T}, N}) where {T<:AbstractFloat, N} = GenericFFT._conv!(complex(deepcopy(u)), deepcopy(v))
3+
conv(u::AbstractArray{Complex{T}, N}, v::AbstractArray{T, N}) where {T<:AbstractFloat, N} = GenericFFT._conv!(deepcopy(u), complex(deepcopy(v)))
4+
conv(u::AbstractArray{Complex{T}, N}, v::AbstractArray{Complex{T}, N}) where {T<:AbstractFloat, N} = GenericFFT._conv!(deepcopy(u), deepcopy(v))

0 commit comments

Comments
 (0)