Skip to content

Generalise Clenshaw #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ deps/build.log
deps/libfasttransforms.*
.DS_Store
deps/FastTransforms/
Manifest.toml
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FastTransforms_jll = "34b6f7d7-08f9-5794-9e10-3819e4c7e49a"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -23,6 +24,7 @@ DSP = "0.6"
FFTW = "1"
FastGaussQuadrature = "0.4"
FastTransforms_jll = "0.3.2"
FillArrays = "0.8"
Reexport = "0.2"
SpecialFunctions = "0.8, 0.9, 0.10"
ToeplitzMatrices = "0.6"
Expand Down
4 changes: 3 additions & 1 deletion src/FastTransforms.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module FastTransforms

using FastGaussQuadrature, LinearAlgebra
using Reexport, SpecialFunctions, ToeplitzMatrices
using Reexport, SpecialFunctions, ToeplitzMatrices, FillArrays

import DSP

Expand Down Expand Up @@ -94,4 +94,6 @@ lgamma(x) = logabsgamma(x)[1]

include("specialfunctions.jl")

include("clenshaw.jl")

end # module
104 changes: 104 additions & 0 deletions src/clenshaw.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@

"""
clenshaw!(c, A, B, C, x)

evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF,
overwriting `x` with the results.
"""
clenshaw!(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector) =
clenshaw!(c, A, B, C, x, Ones{eltype(x)}(length(x)), x)


"""
clenshaw!(c, A, B, C, x, ϕ₀, f)

evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF and ϕ₀ is the zeroth coefficient,
overwriting `f` with the results.
"""
function clenshaw!(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector, ϕ₀::AbstractVector, f::AbstractVector)
f .= ϕ₀ .* clenshaw.(Ref(c), Ref(A), Ref(B), Ref(C), x)
end

"""
clenshaw(c, A, B, C, x)

evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF.
`x` may also be a single `Number`.
"""

function clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::Number)
N = length(c)
T = promote_type(eltype(c),eltype(A),eltype(B),eltype(C),typeof(x))
if length(A) < N || length(B) < N || length(C) < N
throw(ArgumentError("A, B, C must contain at least $N entries"))
end
N == 0 && return zero(T)
@inbounds begin
bk2 = zero(T)
bk1 = convert(T,c[N])
for k = N-1:-1:1
bk1,bk2 = muladd(muladd(A[k],x,B[k]),bk1,muladd(-C[k],bk2,c[k])),bk1
end
end
bk1
end


"""
    clenshaw(c, A, B, C, x::AbstractVector)

Evaluate the orthogonal polynomial expansion with coefficients `c` and recurrence
coefficients `A`, `B`, `C` at every point of `x`, returning a new array and leaving
`x` untouched.

The result's element type is the promotion of the element types of `c`, `A`, `B`, `C`
and `x`, matching what the scalar method returns for each point.
"""
function clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector)
    T = promote_type(eltype(c), eltype(A), eltype(B), eltype(C), eltype(x))
    # Allocate the output with the promoted element type instead of `copy(x)`:
    # writing into a copy of `x` throws `InexactError` for, e.g., integer points with
    # non-integer results, and fails outright for immutable `x` such as ranges.
    f = similar(x, T)
    return clenshaw!(c, A, B, C, x, Ones{T}(length(x)), f)
end

###
# Chebyshev T special cases
###

"""
clenshaw!(c, x)

evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at points `x`,
overwriting `x` with the results.
"""
clenshaw!(c::AbstractVector, x::AbstractVector) = clenshaw!(c, x, x)


"""
clenshaw!(c, x, f)

evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at points `x`,
overwriting `f` with the results.
"""
function clenshaw!(c::AbstractVector, x::AbstractVector, f::AbstractVector)
f .= clenshaw.(Ref(c), x)
end

"""
clenshaw(c, x)

evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at the points `x`.
`x` may also be a single `Number`.
"""
function clenshaw(c::AbstractVector, x::Number)
N,T = length(c),promote_type(eltype(c),typeof(x))
if N == 0
return zero(T)
elseif N == 1 # avoid issues with NaN x
return first(c)*one(x)
end

y = 2x
bk1,bk2 = zero(T),zero(T)
@inbounds begin
for k = N:-1:2
bk1,bk2 = muladd(y,bk1,c[k]-bk2),bk1
end
muladd(x,bk1,c[1]-bk2)
end
end

"""
    clenshaw(c, x::AbstractVector)

Evaluate the first-kind Chebyshev (T) expansion with coefficients `c` at every point of
`x`, returning a new array and leaving `x` untouched.

The result's element type is `promote_type(eltype(c), eltype(x))`, matching what the
scalar method returns for each point.
"""
function clenshaw(c::AbstractVector, x::AbstractVector)
    T = promote_type(eltype(c), eltype(x))
    # Allocate the output with the promoted element type instead of `copy(x)`:
    # writing into a copy of `x` throws `InexactError` for, e.g., integer points with
    # non-integer results, and fails outright for immutable `x` such as ranges.
    return clenshaw!(c, x, similar(x, T))
end

13 changes: 11 additions & 2 deletions src/libfasttransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,42 @@ set_num_threads(n::Integer) = ccall((:ft_set_num_threads, libfasttransforms), Cv
# Evaluate, via the native `ft_horner` routine, the polynomial with coefficients `c` at
# the points `x` (the test suite compares against `sum(c[k]*x^(k-1))`), writing the
# values into `f` and returning `f`.
# Throws `ArgumentError` when `x` and `f` differ in length.
function horner!(c::Vector{Float64}, x::Vector{Float64}, f::Vector{Float64})
    # Explicit check instead of `@assert`: the native routine writes through the raw
    # pointer of `f`, so this guard must never be compiled away.
    length(x) == length(f) ||
        throw(ArgumentError("Dimensions must match: length(x) = $(length(x)), length(f) = $(length(f))"))
    # The literal `1` after `c` is presumably the stride of `c` — confirm against the
    # library header.
    ccall((:ft_horner, libfasttransforms), Cvoid,
          (Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Float64}),
          length(c), c, 1, length(x), x, f)
    return f
end

# Single-precision counterpart: evaluate, via the native `ft_hornerf` routine, the
# polynomial with coefficients `c` at the points `x`, writing the values into `f` and
# returning `f`.
# Throws `ArgumentError` when `x` and `f` differ in length.
function horner!(c::Vector{Float32}, x::Vector{Float32}, f::Vector{Float32})
    # Explicit check instead of `@assert`: the native routine writes through the raw
    # pointer of `f`, so this guard must never be compiled away.
    length(x) == length(f) ||
        throw(ArgumentError("Dimensions must match: length(x) = $(length(x)), length(f) = $(length(f))"))
    # The literal `1` after `c` is presumably the stride of `c` — confirm against the
    # library header.
    ccall((:ft_hornerf, libfasttransforms), Cvoid,
          (Cint, Ptr{Float32}, Cint, Cint, Ptr{Float32}, Ptr{Float32}),
          length(c), c, 1, length(x), x, f)
    return f
end

# Evaluate, via the native `ft_clenshaw` routine, the first-kind Chebyshev expansion
# with coefficients `c` at the points `x` (the test suite compares against
# `sum(c[k]*cos((k-1)*acos(x)))`), writing the values into `f` and returning `f`.
# Throws `ArgumentError` when `x` and `f` differ in length.
function clenshaw!(c::Vector{Float64}, x::Vector{Float64}, f::Vector{Float64})
    # Explicit check instead of `@assert`: the native routine writes through the raw
    # pointer of `f`, so this guard must never be compiled away.
    length(x) == length(f) ||
        throw(ArgumentError("Dimensions must match: length(x) = $(length(x)), length(f) = $(length(f))"))
    # The literal `1` after `c` is presumably the stride of `c` — confirm against the
    # library header.
    ccall((:ft_clenshaw, libfasttransforms), Cvoid,
          (Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Float64}),
          length(c), c, 1, length(x), x, f)
    return f
end

# Single-precision counterpart: evaluate, via the native `ft_clenshawf` routine, the
# first-kind Chebyshev expansion with coefficients `c` at the points `x`, writing the
# values into `f` and returning `f`.
# Throws `ArgumentError` when `x` and `f` differ in length.
function clenshaw!(c::Vector{Float32}, x::Vector{Float32}, f::Vector{Float32})
    # Explicit check instead of `@assert`: the native routine writes through the raw
    # pointer of `f`, so this guard must never be compiled away.
    length(x) == length(f) ||
        throw(ArgumentError("Dimensions must match: length(x) = $(length(x)), length(f) = $(length(f))"))
    # The literal `1` after `c` is presumably the stride of `c` — confirm against the
    # library header.
    ccall((:ft_clenshawf, libfasttransforms), Cvoid,
          (Cint, Ptr{Float32}, Cint, Cint, Ptr{Float32}, Ptr{Float32}),
          length(c), c, 1, length(x), x, f)
    return f
end

function clenshaw!(c::Vector{Float64}, A::Vector{Float64}, B::Vector{Float64}, C::Vector{Float64}, x::Vector{Float64}, phi0::Vector{Float64}, f::Vector{Float64})
@assert length(c) == length(A) == length(B) == length(C)-1
@assert length(x) == length(phi0) == length(f)
N = length(c)
if length(A) < N || length(B) < N || length(C) < N
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MikaelSlevinsky Do you agree that length(C) < N is fine? Before it seemed like it required an extra coefficient for no reason.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because of the DLMF notation for C. You don't need C[0] to get p_1 because p_{-1} == 0, whereas Clenshaw needs C[n] but only A[n-1] and B[n-1] (using C-style, zero-based array indexing).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could get away with pointing to the right place, but I thought an extra entry to match notation would be reasonable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it makes sense to require allocating an extra entry that is never used. I guess we can do pointer(c)-sizeof(T) to work around this?

throw(ArgumentError("A, B, C must contain at least $N entries"))
end
length(x) == length(phi0) == length(f) || throw(ArgumentError("Dimensions must match"))
ccall((:ft_orthogonal_polynomial_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, A, B, C, length(x), x, phi0, f)
f
end

# Single-precision evaluation, via the native `ft_orthogonal_polynomial_clenshawf`
# routine, of the orthogonal polynomial expansion with coefficients `c` and recurrence
# coefficients `A`, `B`, `C` at the points `x`, with `phi0` supplied per point. The
# values are written into `f`, which is returned.
# Throws `ArgumentError` unless `A` and `B` have `length(c)` entries, `C` has
# `length(c) + 1` entries, and `x`, `phi0`, `f` all have the same length.
function clenshaw!(c::Vector{Float32}, A::Vector{Float32}, B::Vector{Float32}, C::Vector{Float32}, x::Vector{Float32}, phi0::Vector{Float32}, f::Vector{Float32})
    # Explicit checks instead of `@assert`: the native routine reads and writes through
    # raw pointers, so these guards must never be compiled away.
    N = length(c)
    # Same condition as before: `C` carries one more entry than `A` and `B`.
    # NOTE(review): per the maintainers' discussion the native routine indexes one entry
    # further into `C` than into `A`/`B`; confirm before relaxing this requirement.
    if length(A) != N || length(B) != N || length(C) != N + 1
        throw(ArgumentError("A and B must contain $N entries and C must contain $(N + 1) entries"))
    end
    length(x) == length(phi0) == length(f) || throw(ArgumentError("Dimensions must match"))
    # The literal `1` after `c` is presumably the stride of `c` — confirm against the
    # library header.
    ccall((:ft_orthogonal_polynomial_clenshawf, libfasttransforms), Cvoid,
          (Cint, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}, Cint,
           Ptr{Float32}, Ptr{Float32}, Ptr{Float32}),
          length(c), c, 1, A, B, C, length(x), x, phi0, f)
    return f
end

const LEG2CHEB = 0
Expand Down
41 changes: 41 additions & 0 deletions test/clenshawtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using FastTransforms, Test
import FastTransforms: clenshaw, clenshaw!

# Tests for the pure-Julia Clenshaw evaluators: the Chebyshev (T) special case and the
# general three-term-recurrence version driven with Chebyshev recurrence coefficients.
@testset "clenshaw" begin
    @testset "Chebyshev" begin
        c = [1,2,3]
        cf = float(c)
        # Integer inputs: results are exact, so identity (value and type) is checked.
        @test @inferred(clenshaw(c,1)) ≡ 1 + 2 + 3
        @test @inferred(clenshaw(c,0)) ≡ 1 + 0 - 3
        # Floating point: the implementation uses `muladd`, whose rounding may differ
        # between machines, so compare approximately and check the type separately.
        @test @inferred(clenshaw(c,0.1)) ≈ 1 + 2*0.1 + 3*cos(2acos(0.1))
        @test clenshaw(c,0.1) isa Float64
        @test @inferred(clenshaw(c,[-1,0,1])) == clenshaw!(c,[-1,0,1]) == [2,-2,6]
        @test clenshaw(c,[-1,0,1]) isa Vector{Int}
        @test @inferred(clenshaw(Float64[],1)) ≡ 0.0

        # All entry points (allocating, in-place, separate output; Int and Float64
        # coefficients) must agree.
        x = [1,0,0.1]
        @test @inferred(clenshaw(c,x)) ≈ @inferred(clenshaw!(c,copy(x))) ≈
            @inferred(clenshaw!(c,x,similar(x))) ≈
            @inferred(clenshaw(cf,x)) ≈ @inferred(clenshaw!(cf,copy(x))) ≈
            @inferred(clenshaw!(cf,x,similar(x))) ≈ [6,-2,-1.74]

    end

    @testset "general" begin
        @testset "Chebyshev-as-general" begin
            # Chebyshev T recurrence coefficients, so results match the block above.
            c, A, B, C = [1,2,3], [1,2,2], fill(0,3), fill(1,3)
            cf, Af, Bf, Cf = float(c), float(A), float(B), float(C)
            @test @inferred(clenshaw(c, A, B, C, 1)) ≡ 6
            # Floating point: `muladd` rounding is platform dependent, so compare
            # approximately and check the type separately.
            @test @inferred(clenshaw(c, A, B, C, 0.1)) ≈ -1.74
            @test clenshaw(c, A, B, C, 0.1) isa Float64
            @test @inferred(clenshaw([1,2,3], A, B, C, [-1,0,1])) == clenshaw!([1,2,3],A, B, C, [-1,0,1]) == [2,-2,6]
            @test clenshaw(c, A, B, C, [-1,0,1]) isa Vector{Int}
            @test @inferred(clenshaw(Float64[], A, B, C, 1)) ≡ 0.0

            x = [1,0,0.1]
            @test @inferred(clenshaw(c,A,B,C,x)) ≈ @inferred(clenshaw!(c,A,B,C,copy(x))) ≈
                @inferred(clenshaw!(c,A,B,C,x,one.(x),similar(x))) ≈
                @inferred(clenshaw!(cf,Af,Bf,Cf,x,one.(x),similar(x))) ≈
                @inferred(clenshaw([1.,2,3],A,B,C,x)) ≈
                @inferred(clenshaw!([1.,2,3],A,B,C,copy(x))) ≈ [6,-2,-1.74]
        end
    end
end
8 changes: 4 additions & 4 deletions test/libfasttransformstests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ FastTransforms.set_num_threads(ceil(Int, Base.Sys.CPU_THREADS/2))
for T in (Float32, Float64)
c = one(T) ./ (1:n)
x = collect(-1 .+ 2*(0:n-1)/T(n))
f = zero(x)
FastTransforms.horner!(c, x, f)
f = similar(x)
@test FastTransforms.horner!(c, x, f) == f
fd = T[sum(c[k]*x^(k-1) for k in 1:length(c)) for x in x]
@test f ≈ fd
FastTransforms.clenshaw!(c, x, f)
@test FastTransforms.clenshaw!(c, x, f) == f
fd = T[sum(c[k]*cos((k-1)*acos(x)) for k in 1:length(c)) for x in x]
@test f ≈ fd
A = T[(2k+one(T))/(k+one(T)) for k in 0:length(c)-1]
B = T[zero(T) for k in 0:length(c)-1]
C = T[k/(k+one(T)) for k in 0:length(c)]
phi0 = ones(T, length(x))
c = cheb2leg(c)
FastTransforms.clenshaw!(c, A, B, C, x, phi0, f)
@test FastTransforms.clenshaw!(c, A, B, C, x, phi0, f) == f
@test f ≈ fd
end

Expand Down
10 changes: 1 addition & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
using FastTransforms, LinearAlgebra, Test

include("specialfunctionstests.jl")

include("chebyshevtests.jl")

include("quadraturetests.jl")

include("libfasttransformstests.jl")

include("nuffttests.jl")

include("fftBigFloattests.jl")

include("paduatests.jl")

include("gaunttests.jl")

include("hermitetests.jl")

include("toeplitztests.jl")
include("clenshawtests.jl")