Skip to content

Generalise Clenshaw #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ deps/build.log
deps/libfasttransforms.*
.DS_Store
deps/FastTransforms/
Manifest.toml
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FastTransforms_jll = "34b6f7d7-08f9-5794-9e10-3819e4c7e49a"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -23,6 +24,7 @@ DSP = "0.6"
FFTW = "1"
FastGaussQuadrature = "0.4"
FastTransforms_jll = "0.3.2"
FillArrays = "0.8"
Reexport = "0.2"
SpecialFunctions = "0.8, 0.9, 0.10"
ToeplitzMatrices = "0.6"
Expand Down
6 changes: 5 additions & 1 deletion src/FastTransforms.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module FastTransforms

using FastGaussQuadrature, LinearAlgebra
using Reexport, SpecialFunctions, ToeplitzMatrices
using Reexport, SpecialFunctions, ToeplitzMatrices, FillArrays

import DSP

Expand All @@ -28,6 +28,8 @@ import FFTW: dct, dct!, idct, idct!, plan_dct!, plan_idct!,

import FastGaussQuadrature: unweightedgausshermite

import FillArrays: AbstractFill, getindex_value

import LinearAlgebra: mul!, lmul!, ldiv!

export leg2cheb, cheb2leg, ultra2ultra, jac2jac,
Expand Down Expand Up @@ -94,4 +96,6 @@ lgamma(x) = logabsgamma(x)[1]

include("specialfunctions.jl")

include("clenshaw.jl")

end # module
152 changes: 152 additions & 0 deletions src/clenshaw.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
forwardrecurrence!(v, A, B, C, x)

evaluates the orthogonal polynomials at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF,
overwriting `v` with the results.
"""
function forwardrecurrence!(v::AbstractVector{T}, A::AbstractVector, B::AbstractVector, C::AbstractVector, x) where T
N = length(v)
N == 0 && return v
length(A)+1 ≥ N && length(B)+1 ≥ N && length(C)+1 ≥ N || throw(ArgumentError("A, B, C must contain at least $(N-1) entries"))
p0 = one(T) # assume OPs are normalized to one for no
p1 = convert(T, N == 1 ? p0 : A[1]x + B[1]) # avoid accessing A[1]/B[1] if empty
_forwardrecurrence!(v, A, B, C, x, p0, p1)
end


Base.@propagate_inbounds _forwardrecurrence_next(n, A, B, C, x, p0, p1) = muladd(muladd(A[n],x,B[n]), p1, -C[n]*p0)
# special case for B[n] == 0
Base.@propagate_inbounds _forwardrecurrence_next(n, A, ::Zeros, C, x, p0, p1) = muladd(A[n]*x, p1, -C[n]*p0)
# special case for Chebyshev U
Base.@propagate_inbounds _forwardrecurrence_next(n, A::AbstractFill, ::Zeros, C::Ones, x, p0, p1) = muladd(getindex_value(A)*x, p1, -p0)


# this supports adaptivity: we can populate `v` for large `n`
function _forwardrecurrence!(v::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x, p0, p1)
N = length(v)
N == 0 && return v
v[1] = p0
N == 1 && return v
v[2] = p1
@inbounds for n = 2:N-1
p1,p0 = _forwardrecurrence_next(n, A, B, C, x, p0, p1),p1
v[n+1] = p1
end
v
end



forwardrecurrence(N::Integer, A::AbstractVector, B::AbstractVector, C::AbstractVector, x) =
forwardrecurrence!(Vector{promote_type(eltype(A),eltype(B),eltype(C),typeof(x))}(undef, N), A, B, C, x)


"""
clenshaw!(c, A, B, C, x)

evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF,
overwriting `x` with the results.
"""
clenshaw!(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector) =
clenshaw!(c, A, B, C, x, Ones{eltype(x)}(length(x)), x)


"""
clenshaw!(c, A, B, C, x, ϕ₀, f)

evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF and ϕ₀ is the zeroth coefficient,
overwriting `f` with the results.
"""
function clenshaw!(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector, ϕ₀::AbstractVector, f::AbstractVector)
f .= ϕ₀ .* clenshaw.(Ref(c), Ref(A), Ref(B), Ref(C), x)
end


@inline _clenshaw_next(n, A, B, C, x, c, bn1, bn2) = muladd(muladd(A[n],x,B[n]), bn1, muladd(-C[n+1],bn2,c[n]))
@inline _clenshaw_next(n, A, ::Zeros, C, x, c, bn1, bn2) = muladd(A[n]*x, bn1, muladd(-C[n+1],bn2,c[n]))
# Chebyshev U
@inline _clenshaw_next(n, A::AbstractFill, ::Zeros, C::Ones, x, c, bn1, bn2) = muladd(getindex_value(A)*x, bn1, -bn2+c[n])

"""
clenshaw(c, A, B, C, x)

evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
as defined in DLMF.
`x` may also be a single `Number`.
"""

function clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::Number)
N = length(c)
T = promote_type(eltype(c),eltype(A),eltype(B),eltype(C),typeof(x))
@boundscheck check_clenshaw_recurrences(N, A, B, C)
N == 0 && return zero(T)
@inbounds begin
bn2 = zero(T)
bn1 = convert(T,c[N])
for n = N-1:-1:1
bn1,bn2 = _clenshaw_next(n, A, B, C, x, c, bn1, bn2),bn1
end
end
bn1
end


clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector) =
clenshaw!(c, A, B, C, copy(x))

###
# Chebyshev T special cases
###

"""
clenshaw!(c, x)

evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at points `x`,
overwriting `x` with the results.
"""
clenshaw!(c::AbstractVector, x::AbstractVector) = clenshaw!(c, x, x)


"""
clenshaw!(c, x, f)

evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at points `x`,
overwriting `f` with the results.
"""
function clenshaw!(c::AbstractVector, x::AbstractVector, f::AbstractVector)
f .= clenshaw.(Ref(c), x)
end

"""
clenshaw(c, x)

evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at the points `x`.
`x` may also be a single `Number`.
"""
function clenshaw(c::AbstractVector, x::Number)
N,T = length(c),promote_type(eltype(c),typeof(x))
if N == 0
return zero(T)
elseif N == 1 # avoid issues with NaN x
return first(c)*one(x)
end

y = 2x
bk1,bk2 = zero(T),zero(T)
@inbounds begin
for k = N:-1:2
bk1,bk2 = muladd(y,bk1,c[k]-bk2),bk1
end
muladd(x,bk1,c[1]-bk2)
end
end

clenshaw(c::AbstractVector, x::AbstractVector) = clenshaw!(c, copy(x))

34 changes: 26 additions & 8 deletions src/libfasttransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,51 @@ set_num_threads(n::Integer) = ccall((:ft_set_num_threads, libfasttransforms), Cv
function horner!(c::Vector{Float64}, x::Vector{Float64}, f::Vector{Float64})
@assert length(x) == length(f)
ccall((:ft_horner, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, length(x), x, f)
f
end

function horner!(c::Vector{Float32}, x::Vector{Float32}, f::Vector{Float32})
@assert length(x) == length(f)
ccall((:ft_hornerf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Cint, Ptr{Float32}, Ptr{Float32}), length(c), c, 1, length(x), x, f)
f
end

function check_clenshaw_recurrences(N, A, B, C)
if length(A) < N || length(B) < N || length(C) < N+1
throw(ArgumentError("A, B must contain at least $N entries and C must contain at least $(N+1) entrie"))
end
end

function check_clenshaw_points(x, ϕ₀, f)
length(x) == length(ϕ₀) == length(f) || throw(ArgumentError("Dimensions must match"))
end

function clenshaw!(c::Vector{Float64}, x::Vector{Float64}, f::Vector{Float64})
@assert length(x) == length(f)
ccall((:ft_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, length(x), x, f)
f
end

function clenshaw!(c::Vector{Float32}, x::Vector{Float32}, f::Vector{Float32})
@assert length(x) == length(f)
ccall((:ft_clenshawf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Cint, Ptr{Float32}, Ptr{Float32}), length(c), c, 1, length(x), x, f)
f
end

function clenshaw!(c::Vector{Float64}, A::Vector{Float64}, B::Vector{Float64}, C::Vector{Float64}, x::Vector{Float64}, phi0::Vector{Float64}, f::Vector{Float64})
@assert length(c) == length(A) == length(B) == length(C)-1
@assert length(x) == length(phi0) == length(f)
ccall((:ft_orthogonal_polynomial_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, A, B, C, length(x), x, phi0, f)
function clenshaw!(c::Vector{Float64}, A::Vector{Float64}, B::Vector{Float64}, C::Vector{Float64}, x::Vector{Float64}, ϕ₀::Vector{Float64}, f::Vector{Float64})
N = length(c)
@boundscheck check_clenshaw_recurrences(N, A, B, C)
@boundscheck check_clenshaw_points(x, ϕ₀, f)
ccall((:ft_orthogonal_polynomial_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), N, c, 1, A, B, C, length(x), x, ϕ₀, f)
f
end

function clenshaw!(c::Vector{Float32}, A::Vector{Float32}, B::Vector{Float32}, C::Vector{Float32}, x::Vector{Float32}, phi0::Vector{Float32}, f::Vector{Float32})
@assert length(c) == length(A) == length(B) == length(C)-1
@assert length(x) == length(phi0) == length(f)
ccall((:ft_orthogonal_polynomial_clenshawf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}), length(c), c, 1, A, B, C, length(x), x, phi0, f)
function clenshaw!(c::Vector{Float32}, A::Vector{Float32}, B::Vector{Float32}, C::Vector{Float32}, x::Vector{Float32}, ϕ₀::Vector{Float32}, f::Vector{Float32})
N = length(c)
@boundscheck check_clenshaw_recurrences(N, A, B, C)
@boundscheck check_clenshaw_points(x, ϕ₀, f)
ccall((:ft_orthogonal_polynomial_clenshawf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}), N, c, 1, A, B, C, length(x), x, ϕ₀, f)
f
end

const LEG2CHEB = 0
Expand Down
106 changes: 106 additions & 0 deletions test/clenshawtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using FastTransforms, FillArrays, Test
import FastTransforms: clenshaw, clenshaw!, forwardrecurrence!, forwardrecurrence

@testset "clenshaw" begin
@testset "Chebyshev T" begin
c = [1,2,3]
cf = float(c)
@test @inferred(clenshaw(c,1)) ≡ 1 + 2 + 3
@test @inferred(clenshaw(c,0)) ≡ 1 + 0 - 3
@test @inferred(clenshaw(c,0.1)) == 1 + 2*0.1 + 3*cos(2acos(0.1))
@test @inferred(clenshaw(c,[-1,0,1])) == clenshaw!(c,[-1,0,1]) == [2,-2,6]
@test clenshaw(c,[-1,0,1]) isa Vector{Int}
@test @inferred(clenshaw(Float64[],1)) ≡ 0.0

x = [1,0,0.1]
@test @inferred(clenshaw(c,x)) ≈ @inferred(clenshaw!(c,copy(x))) ≈
@inferred(clenshaw!(c,x,similar(x))) ≈
@inferred(clenshaw(cf,x)) ≈ @inferred(clenshaw!(cf,copy(x))) ≈
@inferred(clenshaw!(cf,x,similar(x))) ≈ [6,-2,-1.74]
end

@testset "Chebyshev U" begin
N = 5
A, B, C = Fill(2,N-1), Zeros{Int}(N-1), Ones{Int}(N)
@testset "forwardrecurrence!" begin
@test @inferred(forwardrecurrence(N, A, B, C, 1)) == @inferred(forwardrecurrence!(Vector{Int}(undef,N), A, B, C, 1)) == 1:N
@test forwardrecurrence!(Vector{Int}(undef,N), A, B, C, -1) == (-1) .^ (0:N-1) .* (1:N)
@test forwardrecurrence(N, A, B, C, 0.1) ≈ forwardrecurrence!(Vector{Float64}(undef,N), A, B, C, 0.1) ≈
sin.((1:N) .* acos(0.1)) ./ sqrt(1-0.1^2)
end

c = [1,2,3]
@test c'forwardrecurrence(3, A, B, C, 0.1) ≈ clenshaw([1,2,3], A, B, C, 0.1) ≈
1 + (2sin(2acos(0.1)) + 3sin(3acos(0.1)))/sqrt(1-0.1^2)
end

@testset "Chebyshev-as-general" begin
@testset "forwardrecurrence!" begin
N = 5
A, B, C = [1; fill(2,N-2)], fill(0,N-1), fill(1,N)
Af, Bf, Cf = float(A), float(B), float(C)
@test forwardrecurrence(N, A, B, C, 1) == forwardrecurrence!(Vector{Int}(undef,N), A, B, C, 1) == ones(Int,N)
@test forwardrecurrence!(Vector{Int}(undef,N), A, B, C, -1) == (-1) .^ (0:N-1)
@test forwardrecurrence(N, A, B, C, 0.1) ≈ forwardrecurrence!(Vector{Float64}(undef,N), A, B, C, 0.1) ≈ cos.((0:N-1) .* acos(0.1))
end

c, A, B, C = [1,2,3], [1,2,2], fill(0,3), fill(1,4)
cf, Af, Bf, Cf = float(c), float(A), float(B), float(C)
@test @inferred(clenshaw(c, A, B, C, 1)) ≡ 6
@test @inferred(clenshaw(c, A, B, C, 0.1)) ≡ -1.74
@test @inferred(clenshaw([1,2,3], A, B, C, [-1,0,1])) == clenshaw!([1,2,3],A, B, C, [-1,0,1]) == [2,-2,6]
@test clenshaw(c, A, B, C, [-1,0,1]) isa Vector{Int}
@test @inferred(clenshaw(Float64[], A, B, C, 1)) ≡ 0.0

x = [1,0,0.1]
@test @inferred(clenshaw(c, A, B, C, x)) ≈ @inferred(clenshaw!(c, A, B, C, copy(x))) ≈
@inferred(clenshaw!(c, A, B, C, x, one.(x), similar(x))) ≈
@inferred(clenshaw!(cf, Af, Bf, Cf, x, one.(x),similar(x))) ≈
@inferred(clenshaw([1.,2,3], A, B, C, x)) ≈
@inferred(clenshaw!([1.,2,3], A, B, C, copy(x))) ≈ [6,-2,-1.74]
end

@testset "Legendre" begin
@testset "Float64" begin
N = 5
n = 0:N-1
A = (2n .+ 1) ./ (n .+ 1)
B = zeros(N)
C = n ./ (n .+ 1)
v_1 = forwardrecurrence(N, A, B, C, 1)
v_f = forwardrecurrence(N, A, B, C, 0.1)
@test v_1 ≈ ones(N)
@test forwardrecurrence(N, A, B, C, -1) ≈ (-1) .^ (0:N-1)
@test v_f ≈ [1,0.1,-0.485,-0.1475,0.3379375]

n = 0:N # need extra entry for C in Clenshaw
C = n ./ (n .+ 1)
for j = 1:N
c = [zeros(j-1); 1]
@test clenshaw(c, A, B, C, 1) ≈ v_1[j] # Julia code
@test clenshaw(c, A, B, C, 0.1) ≈ v_f[j] # Julia code
@test clenshaw!(c, A, B, C, [1.0,0.1], [1.0,1.0], [0.0,0.0]) ≈ [v_1[j],v_f[j]] # libfasttransforms
end
end

@testset "BigFloat" begin
N = 5
n = BigFloat(0):N-1
A = (2n .+ 1) ./ (n .+ 1)
B = zeros(N)
C = n ./ (n .+ 1)
@test forwardrecurrence(N, A, B, C, parse(BigFloat,"0.1")) ≈ [1,big"0.1",big"-0.485",big"-0.1475",big"0.3379375"]
end
end

@testset "Int" begin
N = 10; A = 1:10; B = 2:11; C = range(3; step=2, length=N+1)
v_i = forwardrecurrence(N, A, B, C, 1)
v_f = forwardrecurrence(N, A, B, C, 0.1)
@test v_i isa Vector{Int}
@test v_f isa Vector{Float64}

j = 3
clenshaw([zeros(Int,j-1); 1; zeros(Int,N-j)], A, B, C, 1) == v_i[j]
end
end
8 changes: 4 additions & 4 deletions test/libfasttransformstests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ FastTransforms.set_num_threads(ceil(Int, Base.Sys.CPU_THREADS/2))
for T in (Float32, Float64)
c = one(T) ./ (1:n)
x = collect(-1 .+ 2*(0:n-1)/T(n))
f = zero(x)
FastTransforms.horner!(c, x, f)
f = similar(x)
@test FastTransforms.horner!(c, x, f) == f
fd = T[sum(c[k]*x^(k-1) for k in 1:length(c)) for x in x]
@test f ≈ fd
FastTransforms.clenshaw!(c, x, f)
@test FastTransforms.clenshaw!(c, x, f) == f
fd = T[sum(c[k]*cos((k-1)*acos(x)) for k in 1:length(c)) for x in x]
@test f ≈ fd
A = T[(2k+one(T))/(k+one(T)) for k in 0:length(c)-1]
B = T[zero(T) for k in 0:length(c)-1]
C = T[k/(k+one(T)) for k in 0:length(c)]
phi0 = ones(T, length(x))
c = cheb2leg(c)
FastTransforms.clenshaw!(c, A, B, C, x, phi0, f)
@test FastTransforms.clenshaw!(c, A, B, C, x, phi0, f) == f
@test f ≈ fd
end

Expand Down
Loading