Skip to content

Commit 68c5ba4

Browse files
mfherbstMikaelSlevinsky
authored andcommitted
Comply with plan_inv function of AbstractFFTs (#85)
1 parent f092c24 commit 68c5ba4

File tree

4 files changed

+59
-38
lines changed

4 files changed

+59
-38
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,20 @@ version = "0.7.0"
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
77
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
88
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
9-
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
109
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
10+
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1111
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1314
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516
ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24"
1617

1718
[compat]
1819
AbstractFFTs = "0.4"
1920
DSP = "0.6"
20-
FastGaussQuadrature = "0.4"
2121
FFTW = "1"
22+
FastGaussQuadrature = "0.4"
2223
SpecialFunctions = "0.8"
2324
ToeplitzMatrices = "0.6"
2425
julia = "1"

src/FastTransforms.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
11
module FastTransforms
22

3-
using AbstractFFTs, DSP, FastGaussQuadrature, FFTW, Libdl, LinearAlgebra, SpecialFunctions, ToeplitzMatrices
3+
using DSP, FastGaussQuadrature, Libdl, LinearAlgebra, SpecialFunctions, ToeplitzMatrices
4+
using Reexport
5+
@reexport using AbstractFFTs
6+
@reexport using FFTW
47

58
import Base: unsafe_convert, eltype, ndims, adjoint, transpose, show, *, \,
69
inv, size, view
710

811
import Base.GMP: Limb
912
import Base.MPFR: BigFloat, _BigFloat
1013

11-
import AbstractFFTs: Plan
14+
import AbstractFFTs: Plan, ScaledPlan,
15+
fft, ifft, bfft, fft!, ifft!, bfft!,
16+
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
17+
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
18+
fftshift, ifftshift,
19+
rfft_output_size, brfft_output_size,
20+
plan_inv, normalization
21+
22+
import FFTW: dct, dct!, idct, idct!, plan_dct!, plan_idct!,
23+
plan_dct, plan_idct, fftwNumber
1224

1325
import DSP: conv
1426

1527
import FastGaussQuadrature: unweightedgausshermite
1628

17-
import FFTW: dct, dct!, idct, idct!,
18-
plan_fft!, plan_ifft!, plan_dct!, plan_idct!,
19-
plan_fft, plan_ifft, plan_rfft, plan_irfft, plan_dct, plan_idct,
20-
plan_bfft, plan_bfft!, plan_brfft, fftwNumber
21-
2229
import LinearAlgebra: mul!, lmul!, ldiv!
2330

2431
export leg2cheb, cheb2leg, ultra2ultra, jac2jac,

src/fftBigFloat.jl

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const ComplexFloats = Complex{T} where T<:AbstractFloat
99
# To add more types, add them in the union of the function's signature.
1010

1111
function generic_fft(x::Vector{T}) where T<:AbstractFloats
12-
T <: FFTW.fftwNumber && (@warn("Using generic fft for FFTW number type."))
12+
T <: FFTW.fftwNumber && (@warn("Using generic fft for FFTW number type."))
1313
n = length(x)
1414
ispow2(n) && return generic_fft_pow2(x)
1515
ks = range(zero(real(T)),stop=n-one(real(T)),length=n)
@@ -37,8 +37,8 @@ end
3737

3838
generic_bfft(x::Vector{T}) where {T <: AbstractFloats} = conj!(generic_fft(conj(x)))
3939
function generic_bfft!(x::Vector{T}) where {T <: AbstractFloats}
40-
x[:] = generic_bfft(x)
41-
return x
40+
x[:] = generic_bfft(x)
41+
return x
4242
end
4343

4444
generic_brfft(v::Vector, n::Integer) = generic_irfft(v, n)*n
@@ -113,9 +113,9 @@ function generic_ifft_pow2(x::Vector{Complex{T}}) where T<:AbstractFloat
113113
end
114114

115115
function generic_dct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
116-
T <: FFTW.fftwNumber && (@warn("Using generic dct for FFTW number type."))
117-
N = length(a)
118-
twoN = convert(T,2) * N
116+
T <: FFTW.fftwNumber && (@warn("Using generic dct for FFTW number type."))
117+
N = length(a)
118+
twoN = convert(T,2) * N
119119
c = generic_fft([a; reverse(a, dims=1)]) # c = generic_fft([a; flipdim(a,1)])
120120
d = c[1:N]
121121
d .*= exp.((-im*convert(T, pi)).*(0:N-1)./twoN)
@@ -126,9 +126,9 @@ end
126126
generic_dct(a::AbstractArray{T}) where {T <: AbstractFloat} = real(generic_dct(complex(a)))
127127

128128
function generic_idct(a::AbstractVector{Complex{T}}) where {T <: AbstractFloat}
129-
T <: FFTW.fftwNumber && (@warn("Using generic idct for FFTW number type."))
130-
N = length(a)
131-
twoN = convert(T,2)*N
129+
T <: FFTW.fftwNumber && (@warn("Using generic idct for FFTW number type."))
130+
N = length(a)
131+
twoN = convert(T,2)*N
132132
b = a * sqrt(twoN)
133133
b[1] = b[1] * sqrt(convert(T,2))
134134
shift = exp.(-im * 2 * convert(T, pi) * (N - convert(T,1)/2) * (0:(2N-1)) / twoN)
@@ -154,32 +154,37 @@ end
154154

155155
# dummy plans
156156
abstract type DummyPlan{T} <: Plan{T} end
157-
struct DummyFFTPlan{T,inplace} <: DummyPlan{T} end
158-
struct DummyiFFTPlan{T,inplace} <: DummyPlan{T} end
159-
struct DummybFFTPlan{T,inplace} <: DummyPlan{T} end
160-
struct DummyDCTPlan{T,inplace} <: DummyPlan{T} end
161-
struct DummyiDCTPlan{T,inplace} <: DummyPlan{T} end
162-
struct DummyrFFTPlan{T,inplace} <: DummyPlan{T}
163-
n::Integer
164-
end
165-
struct DummyirFFTPlan{T,inplace} <: DummyPlan{T}
166-
n::Integer
157+
for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiDCTPlan)
158+
# All plans need an initially undefined pinv field
159+
@eval begin
160+
mutable struct $P{T,inplace} <: DummyPlan{T}
161+
pinv::DummyPlan{T}
162+
$P{T,inplace}() where {T<:AbstractFloats, inplace} = new()
163+
end
164+
end
167165
end
168-
struct DummybrFFTPlan{T,inplace} <: DummyPlan{T}
169-
n::Integer
166+
for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan)
167+
@eval begin
168+
mutable struct $P{T,inplace} <: DummyPlan{T}
169+
n::Integer
170+
pinv::DummyPlan{T}
171+
$P{T,inplace}(n::Integer) where {T<:AbstractFloats, inplace} = new(n)
172+
end
173+
end
170174
end
171175

172176
for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
173177
(:DummyDCTPlan,:DummyiDCTPlan))
174178
@eval begin
175-
Base.inv(::$Plan{T,inplace}) where {T,inplace} = $iPlan{T,inplace}()
176-
Base.inv(::$iPlan{T,inplace}) where {T,inplace} = $Plan{T,inplace}()
179+
plan_inv(::$Plan{T,inplace}) where {T,inplace} = $iPlan{T,inplace}()
180+
plan_inv(::$iPlan{T,inplace}) where {T,inplace} = $Plan{T,inplace}()
177181
end
178182
end
179183

180184
# Specific for rfft, irfft and brfft:
181-
Base.inv(::DummyirFFTPlan{T,inplace}) where {T,inplace} = DummyrFFTPlan{T,Inplace}(p.n)
182-
Base.inv(::DummyrFFTPlan{T,inplace}) where {T,inplace} = DummyirFFTPlan{T,Inplace}(p.n)
185+
plan_inv(p::DummyirFFTPlan{T,inplace}) where {T,inplace} = DummyrFFTPlan{T,Inplace}(p.n)
186+
plan_inv(p::DummyrFFTPlan{T,inplace}) where {T,inplace} = DummyirFFTPlan{T,Inplace}(p.n)
187+
183188

184189

185190
for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!),
@@ -202,14 +207,14 @@ end
202207
*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n)
203208
*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.n)
204209
function mul!(C::StridedVector, p::DummyirFFTPlan, x::StridedVector)
205-
C[:] = generic_irfft(x, p.n)
206-
C
210+
C[:] = generic_irfft(x, p.n)
211+
C
207212
end
208213
*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n)
209214
*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.n)
210215
function mul!(C::StridedVector, p::DummybrFFTPlan, x::StridedVector)
211-
C[:] = generic_brfft(x, p.n)
212-
C
216+
C[:] = generic_brfft(x, p.n)
217+
C
213218
end
214219

215220

test/fftBigFloattests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ using FastTransforms, FFTW, Test
2626
@test norm(idct(dct(c))-c,Inf) < 1000eps(BigFloat)
2727
@test norm(dct(idct(c))-c,Inf) < 1000eps(BigFloat)
2828

29+
c = randn(ComplexF16, 20)
30+
p = plan_fft(c)
31+
@test inv(p) * (p * c) c
32+
33+
c = randn(ComplexF16, 20)
34+
pinpl = plan_fft!(c)
35+
@test inv(pinpl) * (pinpl * c) c
36+
2937
# Make sure we don't accidentally hijack any FFTW plans
3038
for T in (Float32, Float64)
3139
@test plan_fft(rand(BigFloat,10)) isa FastTransforms.DummyPlan

0 commit comments

Comments
 (0)