Skip to content

Commit f89dbb3

Browse files
authored
Generalise Clenshaw (#112)
* Generalise clenshaw for other array types * Support general Clenshaw * Add forwardrecurrence! * Turn on codecov * fix tests * Match libfasttransforms in Clenshaw * Add ChebyshevU special case * use propogate_inbounds
1 parent 8d6e163 commit f89dbb3

File tree

8 files changed

+297
-22
lines changed

8 files changed

+297
-22
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ deps/build.log
44
deps/libfasttransforms.*
55
.DS_Store
66
deps/FastTransforms/
7+
Manifest.toml

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
99
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1010
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1111
FastTransforms_jll = "34b6f7d7-08f9-5794-9e10-3819e4c7e49a"
12+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1213
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -23,6 +24,7 @@ DSP = "0.6"
2324
FFTW = "1"
2425
FastGaussQuadrature = "0.4"
2526
FastTransforms_jll = "0.3.2"
27+
FillArrays = "0.8"
2628
Reexport = "0.2"
2729
SpecialFunctions = "0.8, 0.9, 0.10"
2830
ToeplitzMatrices = "0.6"

src/FastTransforms.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module FastTransforms
22

33
using FastGaussQuadrature, LinearAlgebra
4-
using Reexport, SpecialFunctions, ToeplitzMatrices
4+
using Reexport, SpecialFunctions, ToeplitzMatrices, FillArrays
55

66
import DSP
77

@@ -28,6 +28,8 @@ import FFTW: dct, dct!, idct, idct!, plan_dct!, plan_idct!,
2828

2929
import FastGaussQuadrature: unweightedgausshermite
3030

31+
import FillArrays: AbstractFill, getindex_value
32+
3133
import LinearAlgebra: mul!, lmul!, ldiv!
3234

3335
export leg2cheb, cheb2leg, ultra2ultra, jac2jac,
@@ -94,4 +96,6 @@ lgamma(x) = logabsgamma(x)[1]
9496

9597
include("specialfunctions.jl")
9698

99+
include("clenshaw.jl")
100+
97101
end # module

src/clenshaw.jl

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
forwardrecurrence!(v, A, B, C, x)
3+
4+
evaluates the orthogonal polynomials at points `x`,
5+
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
6+
as defined in DLMF,
7+
overwriting `v` with the results.
8+
"""
9+
function forwardrecurrence!(v::AbstractVector{T}, A::AbstractVector, B::AbstractVector, C::AbstractVector, x) where T
10+
N = length(v)
11+
N == 0 && return v
12+
length(A)+1 N && length(B)+1 N && length(C)+1 N || throw(ArgumentError("A, B, C must contain at least $(N-1) entries"))
13+
p0 = one(T) # assume OPs are normalized to one for no
14+
p1 = convert(T, N == 1 ? p0 : A[1]x + B[1]) # avoid accessing A[1]/B[1] if empty
15+
_forwardrecurrence!(v, A, B, C, x, p0, p1)
16+
end
17+
18+
19+
Base.@propagate_inbounds _forwardrecurrence_next(n, A, B, C, x, p0, p1) = muladd(muladd(A[n],x,B[n]), p1, -C[n]*p0)
20+
# special case for B[n] == 0
21+
Base.@propagate_inbounds _forwardrecurrence_next(n, A, ::Zeros, C, x, p0, p1) = muladd(A[n]*x, p1, -C[n]*p0)
22+
# special case for Chebyshev U
23+
Base.@propagate_inbounds _forwardrecurrence_next(n, A::AbstractFill, ::Zeros, C::Ones, x, p0, p1) = muladd(getindex_value(A)*x, p1, -p0)
24+
25+
26+
# this supports adaptivity: we can populate `v` for large `n`
27+
function _forwardrecurrence!(v::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x, p0, p1)
28+
N = length(v)
29+
N == 0 && return v
30+
v[1] = p0
31+
N == 1 && return v
32+
v[2] = p1
33+
@inbounds for n = 2:N-1
34+
p1,p0 = _forwardrecurrence_next(n, A, B, C, x, p0, p1),p1
35+
v[n+1] = p1
36+
end
37+
v
38+
end
39+
40+
41+
42+
forwardrecurrence(N::Integer, A::AbstractVector, B::AbstractVector, C::AbstractVector, x) =
43+
forwardrecurrence!(Vector{promote_type(eltype(A),eltype(B),eltype(C),typeof(x))}(undef, N), A, B, C, x)
44+
45+
46+
"""
47+
clenshaw!(c, A, B, C, x)
48+
49+
evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
50+
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
51+
as defined in DLMF,
52+
overwriting `x` with the results.
53+
"""
54+
clenshaw!(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector) =
55+
clenshaw!(c, A, B, C, x, Ones{eltype(x)}(length(x)), x)
56+
57+
58+
"""
59+
clenshaw!(c, A, B, C, x, ϕ₀, f)
60+
61+
evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
62+
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
63+
as defined in DLMF and ϕ₀ is the zeroth coefficient,
64+
overwriting `f` with the results.
65+
"""
66+
function clenshaw!(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector, ϕ₀::AbstractVector, f::AbstractVector)
67+
f .= ϕ₀ .* clenshaw.(Ref(c), Ref(A), Ref(B), Ref(C), x)
68+
end
69+
70+
71+
@inline _clenshaw_next(n, A, B, C, x, c, bn1, bn2) = muladd(muladd(A[n],x,B[n]), bn1, muladd(-C[n+1],bn2,c[n]))
72+
@inline _clenshaw_next(n, A, ::Zeros, C, x, c, bn1, bn2) = muladd(A[n]*x, bn1, muladd(-C[n+1],bn2,c[n]))
73+
# Chebyshev U
74+
@inline _clenshaw_next(n, A::AbstractFill, ::Zeros, C::Ones, x, c, bn1, bn2) = muladd(getindex_value(A)*x, bn1, -bn2+c[n])
75+
76+
"""
77+
clenshaw(c, A, B, C, x)
78+
79+
evaluates the orthogonal polynomial expansion with coefficients `c` at points `x`,
80+
where `A`, `B`, and `C` are `AbstractVector`s containing the recurrence coefficients
81+
as defined in DLMF.
82+
`x` may also be a single `Number`.
83+
"""
84+
85+
function clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::Number)
86+
N = length(c)
87+
T = promote_type(eltype(c),eltype(A),eltype(B),eltype(C),typeof(x))
88+
@boundscheck check_clenshaw_recurrences(N, A, B, C)
89+
N == 0 && return zero(T)
90+
@inbounds begin
91+
bn2 = zero(T)
92+
bn1 = convert(T,c[N])
93+
for n = N-1:-1:1
94+
bn1,bn2 = _clenshaw_next(n, A, B, C, x, c, bn1, bn2),bn1
95+
end
96+
end
97+
bn1
98+
end
99+
100+
101+
clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractVector) =
102+
clenshaw!(c, A, B, C, copy(x))
103+
104+
###
105+
# Chebyshev T special cases
106+
###
107+
108+
"""
109+
clenshaw!(c, x)
110+
111+
evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at points `x`,
112+
overwriting `x` with the results.
113+
"""
114+
clenshaw!(c::AbstractVector, x::AbstractVector) = clenshaw!(c, x, x)
115+
116+
117+
"""
118+
clenshaw!(c, x, f)
119+
120+
evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at points `x`,
121+
overwriting `f` with the results.
122+
"""
123+
function clenshaw!(c::AbstractVector, x::AbstractVector, f::AbstractVector)
124+
f .= clenshaw.(Ref(c), x)
125+
end
126+
127+
"""
128+
clenshaw(c, x)
129+
130+
evaluates the first-kind Chebyshev (T) expansion with coefficients `c` at the points `x`.
131+
`x` may also be a single `Number`.
132+
"""
133+
function clenshaw(c::AbstractVector, x::Number)
134+
N,T = length(c),promote_type(eltype(c),typeof(x))
135+
if N == 0
136+
return zero(T)
137+
elseif N == 1 # avoid issues with NaN x
138+
return first(c)*one(x)
139+
end
140+
141+
y = 2x
142+
bk1,bk2 = zero(T),zero(T)
143+
@inbounds begin
144+
for k = N:-1:2
145+
bk1,bk2 = muladd(y,bk1,c[k]-bk2),bk1
146+
end
147+
muladd(x,bk1,c[1]-bk2)
148+
end
149+
end
150+
151+
clenshaw(c::AbstractVector, x::AbstractVector) = clenshaw!(c, copy(x))
152+

src/libfasttransforms.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,51 @@ set_num_threads(n::Integer) = ccall((:ft_set_num_threads, libfasttransforms), Cv
5454
function horner!(c::Vector{Float64}, x::Vector{Float64}, f::Vector{Float64})
5555
@assert length(x) == length(f)
5656
ccall((:ft_horner, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, length(x), x, f)
57+
f
5758
end
5859

5960
function horner!(c::Vector{Float32}, x::Vector{Float32}, f::Vector{Float32})
6061
@assert length(x) == length(f)
6162
ccall((:ft_hornerf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Cint, Ptr{Float32}, Ptr{Float32}), length(c), c, 1, length(x), x, f)
63+
f
64+
end
65+
66+
function check_clenshaw_recurrences(N, A, B, C)
67+
if length(A) < N || length(B) < N || length(C) < N+1
68+
throw(ArgumentError("A, B must contain at least $N entries and C must contain at least $(N+1) entrie"))
69+
end
70+
end
71+
72+
function check_clenshaw_points(x, ϕ₀, f)
73+
length(x) == length(ϕ₀) == length(f) || throw(ArgumentError("Dimensions must match"))
6274
end
6375

6476
function clenshaw!(c::Vector{Float64}, x::Vector{Float64}, f::Vector{Float64})
6577
@assert length(x) == length(f)
6678
ccall((:ft_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Cint, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, length(x), x, f)
79+
f
6780
end
6881

6982
function clenshaw!(c::Vector{Float32}, x::Vector{Float32}, f::Vector{Float32})
7083
@assert length(x) == length(f)
7184
ccall((:ft_clenshawf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Cint, Ptr{Float32}, Ptr{Float32}), length(c), c, 1, length(x), x, f)
85+
f
7286
end
7387

74-
function clenshaw!(c::Vector{Float64}, A::Vector{Float64}, B::Vector{Float64}, C::Vector{Float64}, x::Vector{Float64}, phi0::Vector{Float64}, f::Vector{Float64})
75-
@assert length(c) == length(A) == length(B) == length(C)-1
76-
@assert length(x) == length(phi0) == length(f)
77-
ccall((:ft_orthogonal_polynomial_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), length(c), c, 1, A, B, C, length(x), x, phi0, f)
88+
function clenshaw!(c::Vector{Float64}, A::Vector{Float64}, B::Vector{Float64}, C::Vector{Float64}, x::Vector{Float64}, ϕ₀::Vector{Float64}, f::Vector{Float64})
89+
N = length(c)
90+
@boundscheck check_clenshaw_recurrences(N, A, B, C)
91+
@boundscheck check_clenshaw_points(x, ϕ₀, f)
92+
ccall((:ft_orthogonal_polynomial_clenshaw, libfasttransforms), Cvoid, (Cint, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}), N, c, 1, A, B, C, length(x), x, ϕ₀, f)
93+
f
7894
end
7995

80-
function clenshaw!(c::Vector{Float32}, A::Vector{Float32}, B::Vector{Float32}, C::Vector{Float32}, x::Vector{Float32}, phi0::Vector{Float32}, f::Vector{Float32})
81-
@assert length(c) == length(A) == length(B) == length(C)-1
82-
@assert length(x) == length(phi0) == length(f)
83-
ccall((:ft_orthogonal_polynomial_clenshawf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}), length(c), c, 1, A, B, C, length(x), x, phi0, f)
96+
function clenshaw!(c::Vector{Float32}, A::Vector{Float32}, B::Vector{Float32}, C::Vector{Float32}, x::Vector{Float32}, ϕ₀::Vector{Float32}, f::Vector{Float32})
97+
N = length(c)
98+
@boundscheck check_clenshaw_recurrences(N, A, B, C)
99+
@boundscheck check_clenshaw_points(x, ϕ₀, f)
100+
ccall((:ft_orthogonal_polynomial_clenshawf, libfasttransforms), Cvoid, (Cint, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}, Cint, Ptr{Float32}, Ptr{Float32}, Ptr{Float32}), N, c, 1, A, B, C, length(x), x, ϕ₀, f)
101+
f
84102
end
85103

86104
const LEG2CHEB = 0

test/clenshawtests.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using FastTransforms, FillArrays, Test
2+
import FastTransforms: clenshaw, clenshaw!, forwardrecurrence!, forwardrecurrence
3+
4+
@testset "clenshaw" begin
5+
@testset "Chebyshev T" begin
6+
c = [1,2,3]
7+
cf = float(c)
8+
@test @inferred(clenshaw(c,1)) 1 + 2 + 3
9+
@test @inferred(clenshaw(c,0)) 1 + 0 - 3
10+
@test @inferred(clenshaw(c,0.1)) == 1 + 2*0.1 + 3*cos(2acos(0.1))
11+
@test @inferred(clenshaw(c,[-1,0,1])) == clenshaw!(c,[-1,0,1]) == [2,-2,6]
12+
@test clenshaw(c,[-1,0,1]) isa Vector{Int}
13+
@test @inferred(clenshaw(Float64[],1)) 0.0
14+
15+
x = [1,0,0.1]
16+
@test @inferred(clenshaw(c,x)) @inferred(clenshaw!(c,copy(x)))
17+
@inferred(clenshaw!(c,x,similar(x)))
18+
@inferred(clenshaw(cf,x)) @inferred(clenshaw!(cf,copy(x)))
19+
@inferred(clenshaw!(cf,x,similar(x))) [6,-2,-1.74]
20+
end
21+
22+
@testset "Chebyshev U" begin
23+
N = 5
24+
A, B, C = Fill(2,N-1), Zeros{Int}(N-1), Ones{Int}(N)
25+
@testset "forwardrecurrence!" begin
26+
@test @inferred(forwardrecurrence(N, A, B, C, 1)) == @inferred(forwardrecurrence!(Vector{Int}(undef,N), A, B, C, 1)) == 1:N
27+
@test forwardrecurrence!(Vector{Int}(undef,N), A, B, C, -1) == (-1) .^ (0:N-1) .* (1:N)
28+
@test forwardrecurrence(N, A, B, C, 0.1) forwardrecurrence!(Vector{Float64}(undef,N), A, B, C, 0.1)
29+
sin.((1:N) .* acos(0.1)) ./ sqrt(1-0.1^2)
30+
end
31+
32+
c = [1,2,3]
33+
@test c'forwardrecurrence(3, A, B, C, 0.1) clenshaw([1,2,3], A, B, C, 0.1)
34+
1 + (2sin(2acos(0.1)) + 3sin(3acos(0.1)))/sqrt(1-0.1^2)
35+
end
36+
37+
@testset "Chebyshev-as-general" begin
38+
@testset "forwardrecurrence!" begin
39+
N = 5
40+
A, B, C = [1; fill(2,N-2)], fill(0,N-1), fill(1,N)
41+
Af, Bf, Cf = float(A), float(B), float(C)
42+
@test forwardrecurrence(N, A, B, C, 1) == forwardrecurrence!(Vector{Int}(undef,N), A, B, C, 1) == ones(Int,N)
43+
@test forwardrecurrence!(Vector{Int}(undef,N), A, B, C, -1) == (-1) .^ (0:N-1)
44+
@test forwardrecurrence(N, A, B, C, 0.1) forwardrecurrence!(Vector{Float64}(undef,N), A, B, C, 0.1) cos.((0:N-1) .* acos(0.1))
45+
end
46+
47+
c, A, B, C = [1,2,3], [1,2,2], fill(0,3), fill(1,4)
48+
cf, Af, Bf, Cf = float(c), float(A), float(B), float(C)
49+
@test @inferred(clenshaw(c, A, B, C, 1)) 6
50+
@test @inferred(clenshaw(c, A, B, C, 0.1)) -1.74
51+
@test @inferred(clenshaw([1,2,3], A, B, C, [-1,0,1])) == clenshaw!([1,2,3],A, B, C, [-1,0,1]) == [2,-2,6]
52+
@test clenshaw(c, A, B, C, [-1,0,1]) isa Vector{Int}
53+
@test @inferred(clenshaw(Float64[], A, B, C, 1)) 0.0
54+
55+
x = [1,0,0.1]
56+
@test @inferred(clenshaw(c, A, B, C, x)) @inferred(clenshaw!(c, A, B, C, copy(x)))
57+
@inferred(clenshaw!(c, A, B, C, x, one.(x), similar(x)))
58+
@inferred(clenshaw!(cf, Af, Bf, Cf, x, one.(x),similar(x)))
59+
@inferred(clenshaw([1.,2,3], A, B, C, x))
60+
@inferred(clenshaw!([1.,2,3], A, B, C, copy(x))) [6,-2,-1.74]
61+
end
62+
63+
@testset "Legendre" begin
64+
@testset "Float64" begin
65+
N = 5
66+
n = 0:N-1
67+
A = (2n .+ 1) ./ (n .+ 1)
68+
B = zeros(N)
69+
C = n ./ (n .+ 1)
70+
v_1 = forwardrecurrence(N, A, B, C, 1)
71+
v_f = forwardrecurrence(N, A, B, C, 0.1)
72+
@test v_1 ones(N)
73+
@test forwardrecurrence(N, A, B, C, -1) (-1) .^ (0:N-1)
74+
@test v_f [1,0.1,-0.485,-0.1475,0.3379375]
75+
76+
n = 0:N # need extra entry for C in Clenshaw
77+
C = n ./ (n .+ 1)
78+
for j = 1:N
79+
c = [zeros(j-1); 1]
80+
@test clenshaw(c, A, B, C, 1) v_1[j] # Julia code
81+
@test clenshaw(c, A, B, C, 0.1) v_f[j] # Julia code
82+
@test clenshaw!(c, A, B, C, [1.0,0.1], [1.0,1.0], [0.0,0.0]) [v_1[j],v_f[j]] # libfasttransforms
83+
end
84+
end
85+
86+
@testset "BigFloat" begin
87+
N = 5
88+
n = BigFloat(0):N-1
89+
A = (2n .+ 1) ./ (n .+ 1)
90+
B = zeros(N)
91+
C = n ./ (n .+ 1)
92+
@test forwardrecurrence(N, A, B, C, parse(BigFloat,"0.1")) [1,big"0.1",big"-0.485",big"-0.1475",big"0.3379375"]
93+
end
94+
end
95+
96+
@testset "Int" begin
97+
N = 10; A = 1:10; B = 2:11; C = range(3; step=2, length=N+1)
98+
v_i = forwardrecurrence(N, A, B, C, 1)
99+
v_f = forwardrecurrence(N, A, B, C, 0.1)
100+
@test v_i isa Vector{Int}
101+
@test v_f isa Vector{Float64}
102+
103+
j = 3
104+
clenshaw([zeros(Int,j-1); 1; zeros(Int,N-j)], A, B, C, 1) == v_i[j]
105+
end
106+
end

test/libfasttransformstests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@ FastTransforms.set_num_threads(ceil(Int, Base.Sys.CPU_THREADS/2))
77
for T in (Float32, Float64)
88
c = one(T) ./ (1:n)
99
x = collect(-1 .+ 2*(0:n-1)/T(n))
10-
f = zero(x)
11-
FastTransforms.horner!(c, x, f)
10+
f = similar(x)
11+
@test FastTransforms.horner!(c, x, f) == f
1212
fd = T[sum(c[k]*x^(k-1) for k in 1:length(c)) for x in x]
1313
@test f fd
14-
FastTransforms.clenshaw!(c, x, f)
14+
@test FastTransforms.clenshaw!(c, x, f) == f
1515
fd = T[sum(c[k]*cos((k-1)*acos(x)) for k in 1:length(c)) for x in x]
1616
@test f fd
1717
A = T[(2k+one(T))/(k+one(T)) for k in 0:length(c)-1]
1818
B = T[zero(T) for k in 0:length(c)-1]
1919
C = T[k/(k+one(T)) for k in 0:length(c)]
2020
phi0 = ones(T, length(x))
2121
c = cheb2leg(c)
22-
FastTransforms.clenshaw!(c, A, B, C, x, phi0, f)
22+
@test FastTransforms.clenshaw!(c, A, B, C, x, phi0, f) == f
2323
@test f fd
2424
end
2525

0 commit comments

Comments
 (0)