Skip to content

Commit a7a10d2

Browse files
committed
fix itransforms
1 parent 919fad8 commit a7a10d2

File tree

2 files changed

+101
-61
lines changed

2 files changed

+101
-61
lines changed

src/chebyshevtransform.jl

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ plan_chebyshevtransform(x::AbstractArray, dims...; kws...) = plan_chebyshevtrans
5151

5252

5353
# convert x if necessary
54-
_plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::StridedArray{T}) where T = mul!(y, P, x)
55-
_plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x))
54+
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::StridedArray{T}) where T = mul!(y, P, x)
55+
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x))
5656

57-
_cheb1_rescale!(_, y::AbstractVector) = (y[1] /= 2; ldiv!(length(y), y))
57+
@inline _cheb1_rescale!(_, y::AbstractVector) = (y[1] /= 2; ldiv!(length(y), y))
5858

59-
function _cheb1_rescale!(d::Number, y::AbstractMatrix{T}) where T
59+
@inline function _cheb1_rescale!(d::Number, y::AbstractMatrix{T}) where T
6060
if isone(d)
6161
ldiv!(2, view(y,1,:))
6262
else
@@ -66,22 +66,22 @@ function _cheb1_rescale!(d::Number, y::AbstractMatrix{T}) where T
6666
end
6767

6868
# TODO: higher dimensional arrays
69-
function _cheb1_rescale!(d::UnitRange, y::AbstractMatrix{T}) where T
69+
@inline function _cheb1_rescale!(d::UnitRange, y::AbstractMatrix{T}) where T
7070
@assert d == 1:2
7171
ldiv!(2, view(y,1,:))
7272
ldiv!(2, view(y,:,1))
7373
ldiv!(prod(size(y)), y)
7474
end
7575

76-
function *(P::ChebyshevTransformPlan{T,1,K,true}, x::AbstractArray{T}) where {T,K}
76+
function *(P::ChebyshevTransformPlan{T,1,K,true,N}, x::AbstractArray{T,N}) where {T,K,N}
7777
n = length(x)
7878
n == 0 && return x
7979

8080
y = P.plan*x # will be === x if in-place
8181
_cheb1_rescale!(P.plan.region, y)
8282
end
8383

84-
function mul!(y::AbstractArray{T}, P::ChebyshevTransformPlan{T,1,K,false}, x::AbstractArray) where {T,K}
84+
function mul!(y::AbstractArray{T,N}, P::ChebyshevTransformPlan{T,1,K,false,N}, x::AbstractArray{<:Any,N}) where {T,K,N}
8585
n = length(x)
8686
length(y) == n || throw(DimensionMismatch("output must match dimension"))
8787
n == 0 && return y
@@ -113,20 +113,20 @@ function _cheb2_rescale!(d::UnitRange, y::AbstractMatrix{T}) where T
113113
ldiv!(prod(size(y) .- 1), y)
114114
end
115115

116-
function *(P::ChebyshevTransformPlan{T,2,K,true}, x::AbstractArray{T}) where {T,K}
116+
function *(P::ChebyshevTransformPlan{T,2,K,true,N}, x::AbstractArray{T,N}) where {T,K,N}
117117
n = length(x)
118118
y = P.plan*x # will be === x if in-place
119119
_cheb2_rescale!(P.plan.region, y)
120120
end
121121

122-
function mul!(y::AbstractArray{T}, P::ChebyshevTransformPlan{T,2,K,false}, x::AbstractArray) where {T,K}
122+
function mul!(y::AbstractArray{T,N}, P::ChebyshevTransformPlan{T,2,K,false,N}, x::AbstractArray{<:Any,N}) where {T,K,N}
123123
n = length(x)
124124
length(y) == n || throw(DimensionMismatch("output must match dimension"))
125125
_plan_mul!(y, P.plan, x)
126126
_cheb2_rescale!(P.plan.region, y)
127127
end
128128

129-
*(P::ChebyshevTransformPlan{T,kind,K,false}, x::AbstractArray{T}) where {T,kind,K} =
129+
*(P::ChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} =
130130
mul!(similar(x), P, x)
131131

132132
"""
@@ -202,18 +202,37 @@ end
202202
plan_ichebyshevtransform!(x::AbstractArray, dims...; kws...) = plan_ichebyshevtransform!(x, Val(1), dims...; kws...)
203203
plan_ichebyshevtransform(x::AbstractArray, dims...; kws...) = plan_ichebyshevtransform(x, Val(1), dims...; kws...)
204204

205-
_icheb1_prerescale!(_, x::AbstractVector) = (x[1] *= 2)
206-
_icheb1_postrescale!(_, x::AbstractVector) = (x[1] /= 2)
207-
function _icheb1_prerescale!(d::Number, x::AbstractVector)
208-
lmul!(2, isone(d) ? view(x,:,1) : view(x,1,:))
205+
@inline _icheb1_prerescale!(_, x::AbstractVector) = (x[1] *= 2)
206+
@inline function _icheb1_prerescale!(d::Number, x::AbstractMatrix)
207+
if isone(d)
208+
lmul!(2, view(x,1,:))
209+
else
210+
lmul!(2, view(x,:,1))
211+
end
212+
x
213+
end
214+
@inline function _icheb1_prerescale!(d::UnitRange, x::AbstractMatrix)
215+
lmul!(2, view(x,:,1))
216+
lmul!(2, view(x,1,:))
217+
x
218+
end
219+
@inline _icheb1_postrescale!(_, x::AbstractVector) = (x[1] /= 2)
220+
@inline function _icheb1_postrescale!(d::Number, x::AbstractMatrix)
221+
if isone(d)
222+
ldiv!(2, view(x,1,:))
223+
else
224+
ldiv!(2, view(x,:,1))
225+
end
209226
x
210227
end
211-
function _icheb1_postrescale!(_, x::AbstractVector)
212-
ldiv(2, isone(d) ? view(x,:,1) : view(x,1,:))
228+
229+
@inline function _icheb1_postrescale!(d::UnitRange, x::AbstractMatrix)
230+
ldiv!(2, view(x,1,:))
231+
ldiv!(2, view(x,:,1))
213232
x
214233
end
215234

216-
function *(P::IChebyshevTransformPlan{T,1,K,true}, x::AbstractVector{T}) where {T<:fftwNumber,K}
235+
function *(P::IChebyshevTransformPlan{T,1,K,true,N}, x::AbstractArray{T,N}) where {T<:fftwNumber,K,N}
217236
n = length(x)
218237
n == 0 && return x
219238

@@ -222,7 +241,7 @@ function *(P::IChebyshevTransformPlan{T,1,K,true}, x::AbstractVector{T}) where {
222241
x
223242
end
224243

225-
function mul!(y::AbstractVector{T}, P::IChebyshevTransformPlan{T,1,K,false}, x::AbstractVector{T}) where {T<:fftwNumber,K}
244+
function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,1,K,false,N}, x::AbstractArray{T,N}) where {T<:fftwNumber,K,N}
226245
n = length(x)
227246
length(y) == n || throw(DimensionMismatch("output must match dimension"))
228247
n == 0 && return y
@@ -233,31 +252,62 @@ function mul!(y::AbstractVector{T}, P::IChebyshevTransformPlan{T,1,K,false}, x::
233252
ldiv!(2^length(P.plan.region), y)
234253
end
235254

236-
_icheb2_prerescale!(_, x::AbstractVector) = (x[1] *= 2; x[end] *= 2)
237-
_icheb2_postrescale!(_, x::AbstractVector) = (x[1] /= 2; x[end] /= 2)
238-
function _icheb2_rescale!(d, y::AbstractVector)
239-
_icheb2_prerescale!(d, y)
240-
lmul!(convert(T, prod(size(y) .- 1))/2, y)
241-
y
255+
@inline _icheb2_prerescale!(_, x::AbstractVector) = (x[1] *= 2; x[end] *= 2)
256+
@inline function _icheb2_prerescale!(d::Number, x::AbstractMatrix)
257+
if isone(d)
258+
lmul!(2, @view(x[1,:]))
259+
lmul!(2, @view(x[end,:]))
260+
else
261+
lmul!(2, @view(x[:,1]))
262+
lmul!(2, @view(x[:,end]))
263+
end
264+
x
242265
end
243-
function _icheb2_prerescale!(d::Number, x::AbstractVector)
244-
lmul!(2, isone(d) ? view(x,:,1) : view(x,1,:))
266+
@inline function _icheb2_prerescale!(d::UnitRange, x::AbstractMatrix)
267+
lmul!(2, @view(x[1,:]))
268+
lmul!(2, @view(x[end,:]))
269+
lmul!(2, @view(x[:,1]))
270+
lmul!(2, @view(x[:,end]))
271+
x
272+
end
273+
@inline _icheb2_postrescale!(_, x::AbstractVector) = (x[1] /= 2; x[end] /= 2)
274+
@inline function _icheb2_postrescale!(d::Number, x::AbstractMatrix)
275+
if isone(d)
276+
ldiv!(2, @view(x[1,:]))
277+
ldiv!(2, @view(x[end,:]))
278+
else
279+
ldiv!(2, @view(x[:,1]))
280+
ldiv!(2, @view(x[:,end]))
281+
end
245282
x
246283
end
247-
function _icheb2_postrescale!(_, x::AbstractVector)
248-
ldiv(2, isone(d) ? view(x,:,1) : view(x,1,:))
284+
@inline function _icheb2_postrescale!(d::UnitRange, x::AbstractMatrix)
285+
ldiv!(2, @view(x[1,:]))
286+
ldiv!(2, @view(x[end,:]))
287+
ldiv!(2, @view(x[:,1]))
288+
ldiv!(2, @view(x[:,end]))
249289
x
250290
end
291+
@inline function _icheb2_rescale!(d::Number, y::AbstractArray{T}) where T
292+
_icheb2_prerescale!(d, y)
293+
lmul!(convert(T, size(y,d) - 1)/2, y)
294+
y
295+
end
296+
@inline function _icheb2_rescale!(d::UnitRange, y::AbstractArray{T}) where T
297+
_icheb2_prerescale!(d, y)
298+
lmul!(prod(convert.(T, size(y) .- 1)./2), y)
299+
y
300+
end
251301

252-
function *(P::IChebyshevTransformPlan{T,2,K, true}, x::AbstractVector{T}) where {T<:fftwNumber,K}
302+
function *(P::IChebyshevTransformPlan{T,2,K,true,N}, x::AbstractArray{T,N}) where {T<:fftwNumber,K,N}
253303
n = length(x)
254304

255305
_icheb2_prerescale!(P.plan.region, x)
256306
x = inv(P)*x
257307
_icheb2_rescale!(P.plan.region, x)
258308
end
259309

260-
function mul!(y::AbstractVector{T}, P::IChebyshevTransformPlan{T,2,K,false}, x::AbstractVector{T}) where {T<:fftwNumber,K}
310+
function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,2,K,false,N}, x::AbstractArray{<:Any,N}) where {T<:fftwNumber,K,N}
261311
n = length(x)
262312
length(y) == n || throw(DimensionMismatch("output must match dimension"))
263313

@@ -267,7 +317,7 @@ function mul!(y::AbstractVector{T}, P::IChebyshevTransformPlan{T,2,K,false}, x::
267317
_icheb2_rescale!(P.plan.region, y)
268318
end
269319

270-
*(P::IChebyshevTransformPlan{T,kind,K,false},x::AbstractVector{T}) where {T,kind,K} = mul!(similar(x), P, convert(Array,x))
320+
*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} = mul!(similar(x), P, x)
271321
ichebyshevtransform!(x::AbstractArray, dims...; kwds...) = plan_ichebyshevtransform!(x, dims...; kwds...)*x
272322
ichebyshevtransform(x, dims...; kwds...) = plan_ichebyshevtransform(x, dims...; kwds...)*x
273323

test/chebyshevtests.jl

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -174,43 +174,33 @@ using FastTransforms, Test
174174

175175
@testset "matrix" begin
176176
X = randn(4,5)
177-
@test @inferred(chebyshevtransform(X,1)) @inferred(chebyshevtransform!(copy(X),1)) hcat(chebyshevtransform.([X[:,k] for k=axes(X,2)])...)
178-
@test chebyshevtransform(X,2) chebyshevtransform!(copy(X),2) hcat(chebyshevtransform.([X[k,:] for k=axes(X,1)])...)'
179-
@test @inferred(chebyshevtransform(X,Val(2),1)) @inferred(chebyshevtransform!(copy(X),Val(2),1)) hcat(chebyshevtransform.([X[:,k] for k=axes(X,2)],Val(2))...)
180-
@test chebyshevtransform(X,Val(2),2) chebyshevtransform!(copy(X),Val(2),2) hcat(chebyshevtransform.([X[k,:] for k=axes(X,1)],Val(2))...)'
177+
@testset "chebyshevtransform" begin
178+
@test @inferred(chebyshevtransform(X,1)) @inferred(chebyshevtransform!(copy(X),1)) hcat(chebyshevtransform.([X[:,k] for k=axes(X,2)])...)
179+
@test chebyshevtransform(X,2) chebyshevtransform!(copy(X),2) hcat(chebyshevtransform.([X[k,:] for k=axes(X,1)])...)'
180+
@test @inferred(chebyshevtransform(X,Val(2),1)) @inferred(chebyshevtransform!(copy(X),Val(2),1)) hcat(chebyshevtransform.([X[:,k] for k=axes(X,2)],Val(2))...)
181+
@test chebyshevtransform(X,Val(2),2) chebyshevtransform!(copy(X),Val(2),2) hcat(chebyshevtransform.([X[k,:] for k=axes(X,1)],Val(2))...)'
182+
183+
@test @inferred(chebyshevtransform(X)) @inferred(chebyshevtransform!(copy(X))) chebyshevtransform(chebyshevtransform(X,1),2)
184+
@test @inferred(chebyshevtransform(X,Val(2))) @inferred(chebyshevtransform!(copy(X),Val(2))) chebyshevtransform(chebyshevtransform(X,Val(2),1),Val(2),2)
185+
end
186+
187+
@testset "ichebyshevtransform" begin
188+
@test @inferred(ichebyshevtransform(X,1)) @inferred(ichebyshevtransform!(copy(X),1)) hcat(ichebyshevtransform.([X[:,k] for k=axes(X,2)])...)
189+
@test ichebyshevtransform(X,2) ichebyshevtransform!(copy(X),2) hcat(ichebyshevtransform.([X[k,:] for k=axes(X,1)])...)'
190+
@test @inferred(ichebyshevtransform(X,Val(2),1)) @inferred(ichebyshevtransform!(copy(X),Val(2),1)) hcat(ichebyshevtransform.([X[:,k] for k=axes(X,2)],Val(2))...)
191+
@test ichebyshevtransform(X,Val(2),2) ichebyshevtransform!(copy(X),Val(2),2) hcat(ichebyshevtransform.([X[k,:] for k=axes(X,1)],Val(2))...)'
181192

182-
@test @inferred(chebyshevtransform(X)) @inferred(chebyshevtransform!(copy(X))) chebyshevtransform(chebyshevtransform(X,1),2)
183-
@test @inferred(chebyshevtransform(X,Val(2))) @inferred(chebyshevtransform!(copy(X),Val(2))) chebyshevtransform(chebyshevtransform(X,Val(2),1),Val(2),2)
193+
@test @inferred(ichebyshevtransform(X)) @inferred(ichebyshevtransform!(copy(X))) ichebyshevtransform(ichebyshevtransform(X,1),2)
194+
@test @inferred(ichebyshevtransform(X,Val(2))) @inferred(ichebyshevtransform!(copy(X),Val(2))) ichebyshevtransform(ichebyshevtransform(X,Val(2),1),Val(2),2)
184195

196+
@test ichebyshevtransform(chebyshevtransform(X)) X
197+
@test chebyshevtransform(ichebyshevtransform(X)) X
198+
end
185199

186200
X = randn(1,1)
187201
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
188202
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
189203
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))
190-
191-
X = randn(10,11)
192-
193-
# manual 2D Chebyshev
194-
= copy(X)
195-
for j in axes(X̌,2)
196-
chebyshevtransform!(view(X̌,:,j))
197-
end
198-
for k in axes(X̌,1)
199-
chebyshevtransform!(view(X̌,k,:))
200-
end
201-
@test chebyshevtransform!(copy(X), Val(1))
202-
@test ichebyshevtransform!(copy(X̌), Val(1)) X
203-
204-
# manual 2D Chebyshev
205-
= copy(X)
206-
for j in axes(X̌,2)
207-
chebyshevtransform!(view(X̌,:,j), Val(2))
208-
end
209-
for k in axes(X̌,1)
210-
chebyshevtransform!(view(X̌,k,:), Val(2))
211-
end
212-
@test chebyshevtransform!(copy(X), Val(2))
213-
@test ichebyshevtransform!(copy(X̌), Val(2)) X
214204
end
215205

216206
@testset "Integer" begin

0 commit comments

Comments
 (0)