Skip to content

Commit 51688b7

Browse files
authored
Support 3D tensors (#185)
* Start supporting 3D tensors * Tensor Chebyshev 2 * ichebyshevtransform
1 parent 6e8fbb6 commit 51688b7

File tree

3 files changed

+176
-76
lines changed

3 files changed

+176
-76
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FastTransforms"
22
uuid = "057dd010-8810-581a-b7be-e3fc3b93f78c"
3-
version = "0.14.3"
3+
version = "0.14.4"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/chebyshevtransform.jl

Lines changed: 125 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,100 @@ plan_chebyshevtransform(x::AbstractArray, dims...; kws...) = plan_chebyshevtrans
5454
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::StridedArray{T}) where T = mul!(y, P, x)
5555
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x))
5656

57-
@inline _cheb1_rescale!(_, y::AbstractVector) = (y[1] /= 2; ldiv!(length(y), y))
5857

59-
@inline function _cheb1_rescale!(d::Number, y::AbstractMatrix{T}) where T
58+
59+
ldiv_dim_begin!(α, d::Number, y::AbstractVector) = y[1] /= α
60+
function ldiv_dim_begin!(α, d::Number, y::AbstractMatrix)
61+
if isone(d)
62+
ldiv!(α, @view(y[1,:]))
63+
else
64+
ldiv!(α, @view(y[:,1]))
65+
end
66+
end
67+
function ldiv_dim_begin!(α, d::Number, y::AbstractArray{<:Any,3})
68+
if isone(d)
69+
ldiv!(α, @view(y[1,:,:]))
70+
elseif d == 2
71+
ldiv!(α, @view(y[:,1,:]))
72+
else # d == 3
73+
ldiv!(α, @view(y[:,:,1]))
74+
end
75+
end
76+
77+
ldiv_dim_end!(α, d::Number, y::AbstractVector) = y[end] /= α
78+
function ldiv_dim_end!(α, d::Number, y::AbstractMatrix)
79+
if isone(d)
80+
ldiv!(α, @view(y[end,:]))
81+
else
82+
ldiv!(α, @view(y[:,end]))
83+
end
84+
end
85+
function ldiv_dim_end!(α, d::Number, y::AbstractArray{<:Any,3})
86+
if isone(d)
87+
ldiv!(α, @view(y[end,:,:]))
88+
elseif d == 2
89+
ldiv!(α, @view(y[:,end,:]))
90+
else # d == 3
91+
ldiv!(α, @view(y[:,:,end]))
92+
end
93+
end
94+
95+
lmul_dim_begin!(α, d::Number, y::AbstractVector) = y[1] *= α
96+
function lmul_dim_begin!(α, d::Number, y::AbstractMatrix)
6097
if isone(d)
61-
ldiv!(2, view(y,1,:))
98+
lmul!(α, @view(y[1,:]))
6299
else
63-
ldiv!(2, view(y,:,1))
100+
lmul!(α, @view(y[:,1]))
101+
end
102+
end
103+
function lmul_dim_begin!(α, d::Number, y::AbstractArray{<:Any,3})
104+
if isone(d)
105+
lmul!(α, @view(y[1,:,:]))
106+
elseif d == 2
107+
lmul!(α, @view(y[:,1,:]))
108+
else # d == 3
109+
lmul!(α, @view(y[:,:,1]))
64110
end
111+
end
112+
113+
lmul_dim_end!(α, d::Number, y::AbstractVector) = y[end] *= α
114+
function lmul_dim_end!(α, d::Number, y::AbstractMatrix)
115+
if isone(d)
116+
lmul!(α, @view(y[end,:]))
117+
else
118+
lmul!(α, @view(y[:,end]))
119+
end
120+
end
121+
function lmul_dim_end!(α, d::Number, y::AbstractArray{<:Any,3})
122+
if isone(d)
123+
lmul!(α, @view(y[end,:,:]))
124+
elseif d == 2
125+
lmul!(α, @view(y[:,end,:]))
126+
else # d == 3
127+
lmul!(α, @view(y[:,:,end]))
128+
end
129+
end
130+
131+
132+
@inline function _cheb1_rescale!(d::Number, y::AbstractArray)
133+
ldiv_dim_begin!(2, d, y)
65134
ldiv!(size(y,d), y)
66135
end
67136

68-
# TODO: higher dimensional arrays
69-
@inline function _cheb1_rescale!(d::UnitRange, y::AbstractMatrix{T}) where T
70-
@assert d == 1:2
71-
ldiv!(2, view(y,1,:))
72-
ldiv!(2, view(y,:,1))
73-
ldiv!(prod(size(y)), y)
137+
function _prod_size(sz, d)
138+
ret = 1
139+
for k in d
140+
ret *= sz[k]
141+
end
142+
ret
143+
end
144+
145+
146+
@inline function _cheb1_rescale!(d::UnitRange, y::AbstractArray)
147+
for k in d
148+
ldiv_dim_begin!(2, k, y)
149+
end
150+
ldiv!(_prod_size(size(y), d), y)
74151
end
75152

76153
function *(P::ChebyshevTransformPlan{T,1,K,true,N}, x::AbstractArray{T,N}) where {T,K,N}
@@ -90,27 +167,21 @@ function mul!(y::AbstractArray{T,N}, P::ChebyshevTransformPlan{T,1,K,false,N}, x
90167
end
91168

92169

93-
_cheb2_rescale!(_, y::AbstractVector) = (y[1] /= 2; y[end] /= 2; ldiv!(length(y)-1, y))
94170

95-
function _cheb2_rescale!(d::Number, y::AbstractMatrix{T}) where T
96-
if isone(d)
97-
ldiv!(2, @view(y[1,:]))
98-
ldiv!(2, @view(y[end,:]))
99-
else
100-
ldiv!(2, @view(y[:,1]))
101-
ldiv!(2, @view(y[:,end]))
102-
end
171+
function _cheb2_rescale!(d::Number, y::AbstractArray)
172+
ldiv_dim_begin!(2, d, y)
173+
ldiv_dim_end!(2, d, y)
103174
ldiv!(size(y,d)-1, y)
104175
end
105176

106177
# TODO: higher dimensional arrays
107-
function _cheb2_rescale!(d::UnitRange, y::AbstractMatrix{T}) where T
108-
@assert d == 1:2
109-
ldiv!(2, @view(y[1,:]))
110-
ldiv!(2, @view(y[end,:]))
111-
ldiv!(2, @view(y[:,1]))
112-
ldiv!(2, @view(y[:,end]))
113-
ldiv!(prod(size(y) .- 1), y)
178+
function _cheb2_rescale!(d::UnitRange, y::AbstractArray)
179+
for k in d
180+
ldiv_dim_begin!(2, k, y)
181+
ldiv_dim_end!(2, k, y)
182+
end
183+
184+
ldiv!(_prod_size(size(y) .- 1, d), y)
114185
end
115186

116187
function *(P::ChebyshevTransformPlan{T,2,K,true,N}, x::AbstractArray{T,N}) where {T,K,N}
@@ -200,33 +271,25 @@ end
200271
plan_ichebyshevtransform!(x::AbstractArray, dims...; kws...) = plan_ichebyshevtransform!(x, Val(1), dims...; kws...)
201272
plan_ichebyshevtransform(x::AbstractArray, dims...; kws...) = plan_ichebyshevtransform(x, Val(1), dims...; kws...)
202273

203-
@inline _icheb1_prescale!(_, x::AbstractVector) = (x[1] *= 2)
204-
@inline function _icheb1_prescale!(d::Number, x::AbstractMatrix)
205-
if isone(d)
206-
lmul!(2, view(x,1,:))
207-
else
208-
lmul!(2, view(x,:,1))
209-
end
274+
@inline function _icheb1_prescale!(d::Number, x::AbstractArray)
275+
lmul_dim_begin!(2, d, x)
210276
x
211277
end
212-
@inline function _icheb1_prescale!(d::UnitRange, x::AbstractMatrix)
213-
lmul!(2, view(x,:,1))
214-
lmul!(2, view(x,1,:))
278+
@inline function _icheb1_prescale!(d::UnitRange, x::AbstractArray)
279+
for k in d
280+
_icheb1_prescale!(k, x)
281+
end
215282
x
216283
end
217-
@inline _icheb1_postscale!(_, x::AbstractVector) = (x[1] /= 2)
218-
@inline function _icheb1_postscale!(d::Number, x::AbstractMatrix)
219-
if isone(d)
220-
ldiv!(2, view(x,1,:))
221-
else
222-
ldiv!(2, view(x,:,1))
223-
end
284+
@inline function _icheb1_postscale!(d::Number, x::AbstractArray)
285+
ldiv_dim_begin!(2, d, x)
224286
x
225287
end
226288

227-
@inline function _icheb1_postscale!(d::UnitRange, x::AbstractMatrix)
228-
ldiv!(2, view(x,1,:))
229-
ldiv!(2, view(x,:,1))
289+
@inline function _icheb1_postscale!(d::UnitRange, x::AbstractArray)
290+
for k in d
291+
_icheb1_postscale!(k, x)
292+
end
230293
x
231294
end
232295

@@ -249,40 +312,27 @@ function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,1,K,false,N},
249312
ldiv!(2^length(P.plan.region), y)
250313
end
251314

252-
@inline _icheb2_prescale!(_, x::AbstractVector) = (x[1] *= 2; x[end] *= 2)
253-
@inline function _icheb2_prescale!(d::Number, x::AbstractMatrix)
254-
if isone(d)
255-
lmul!(2, @view(x[1,:]))
256-
lmul!(2, @view(x[end,:]))
257-
else
258-
lmul!(2, @view(x[:,1]))
259-
lmul!(2, @view(x[:,end]))
260-
end
315+
@inline function _icheb2_prescale!(d::Number, x::AbstractArray)
316+
lmul_dim_begin!(2, d, x)
317+
lmul_dim_end!(2, d, x)
261318
x
262319
end
263-
@inline function _icheb2_prescale!(d::UnitRange, x::AbstractMatrix)
264-
lmul!(2, @view(x[1,:]))
265-
lmul!(2, @view(x[end,:]))
266-
lmul!(2, @view(x[:,1]))
267-
lmul!(2, @view(x[:,end]))
320+
@inline function _icheb2_prescale!(d::UnitRange, x::AbstractArray)
321+
for k in d
322+
_icheb2_prescale!(k, x)
323+
end
268324
x
269325
end
270-
@inline _icheb2_postrescale!(_, x::AbstractVector) = (x[1] /= 2; x[end] /= 2)
271-
@inline function _icheb2_postrescale!(d::Number, x::AbstractMatrix)
272-
if isone(d)
273-
ldiv!(2, @view(x[1,:]))
274-
ldiv!(2, @view(x[end,:]))
275-
else
276-
ldiv!(2, @view(x[:,1]))
277-
ldiv!(2, @view(x[:,end]))
278-
end
326+
327+
@inline function _icheb2_postrescale!(d::Number, x::AbstractArray)
328+
ldiv_dim_begin!(2, d, x)
329+
ldiv_dim_end!(2, d, x)
279330
x
280331
end
281-
@inline function _icheb2_postrescale!(d::UnitRange, x::AbstractMatrix)
282-
ldiv!(2, @view(x[1,:]))
283-
ldiv!(2, @view(x[end,:]))
284-
ldiv!(2, @view(x[:,1]))
285-
ldiv!(2, @view(x[:,end]))
332+
@inline function _icheb2_postrescale!(d::UnitRange, x::AbstractArray)
333+
for k in d
334+
_icheb2_postrescale!(k, x)
335+
end
286336
x
287337
end
288338
@inline function _icheb2_rescale!(d::Number, y::AbstractArray{T}) where T
@@ -292,7 +342,7 @@ end
292342
end
293343
@inline function _icheb2_rescale!(d::UnitRange, y::AbstractArray{T}) where T
294344
_icheb2_prescale!(d, y)
295-
lmul!(prod(convert.(T, size(y) .- 1)./2), y)
345+
lmul!(_prod_size(convert.(T, size(y) .- 1)./2, d), y)
296346
y
297347
end
298348

test/chebyshevtests.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,56 @@ using FastTransforms, Test
203203
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))
204204
end
205205

206+
@testset "tensor" begin
207+
X = randn(4,5,6)
208+
= similar(X)
209+
@testset "chebyshevtransform" begin
210+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevtransform(X[:,k,j]) end
211+
@test @inferred(chebyshevtransform(X,1)) @inferred(chebyshevtransform!(copy(X),1))
212+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevtransform(X[k,:,j]) end
213+
@test chebyshevtransform(X,2) chebyshevtransform!(copy(X),2)
214+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevtransform(X[k,j,:]) end
215+
@test chebyshevtransform(X,3) chebyshevtransform!(copy(X),3)
216+
217+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevtransform(X[:,k,j],Val(2)) end
218+
@test @inferred(chebyshevtransform(X,Val(2),1)) @inferred(chebyshevtransform!(copy(X),Val(2),1))
219+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevtransform(X[k,:,j],Val(2)) end
220+
@test chebyshevtransform(X,Val(2),2) chebyshevtransform!(copy(X),Val(2),2)
221+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevtransform(X[k,j,:],Val(2)) end
222+
@test chebyshevtransform(X,Val(2),3) chebyshevtransform!(copy(X),Val(2),3)
223+
224+
@test @inferred(chebyshevtransform(X)) @inferred(chebyshevtransform!(copy(X))) chebyshevtransform(chebyshevtransform(chebyshevtransform(X,1),2),3)
225+
@test @inferred(chebyshevtransform(X,Val(2))) @inferred(chebyshevtransform!(copy(X),Val(2))) chebyshevtransform(chebyshevtransform(chebyshevtransform(X,Val(2),1),Val(2),2),Val(2),3)
226+
end
227+
228+
@testset "ichebyshevtransform" begin
229+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevtransform(X[:,k,j]) end
230+
@test @inferred(ichebyshevtransform(X,1)) @inferred(ichebyshevtransform!(copy(X),1))
231+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevtransform(X[k,:,j]) end
232+
@test ichebyshevtransform(X,2) ichebyshevtransform!(copy(X),2)
233+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevtransform(X[k,j,:]) end
234+
@test ichebyshevtransform(X,3) ichebyshevtransform!(copy(X),3)
235+
236+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevtransform(X[:,k,j],Val(2)) end
237+
@test @inferred(ichebyshevtransform(X,Val(2),1)) @inferred(ichebyshevtransform!(copy(X),Val(2),1))
238+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevtransform(X[k,:,j],Val(2)) end
239+
@test ichebyshevtransform(X,Val(2),2) ichebyshevtransform!(copy(X),Val(2),2)
240+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevtransform(X[k,j,:],Val(2)) end
241+
@test ichebyshevtransform(X,Val(2),3) ichebyshevtransform!(copy(X),Val(2),3)
242+
243+
@test @inferred(ichebyshevtransform(X)) @inferred(ichebyshevtransform!(copy(X))) ichebyshevtransform(ichebyshevtransform(ichebyshevtransform(X,1),2),3)
244+
@test @inferred(ichebyshevtransform(X,Val(2))) @inferred(ichebyshevtransform!(copy(X),Val(2))) ichebyshevtransform(ichebyshevtransform(ichebyshevtransform(X,Val(2),1),Val(2),2),Val(2),3)
245+
246+
@test ichebyshevtransform(chebyshevtransform(X)) X
247+
@test chebyshevtransform(ichebyshevtransform(X)) X
248+
end
249+
250+
X = randn(1,1,1)
251+
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
252+
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
253+
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))
254+
end
255+
206256
@testset "Integer" begin
207257
@test chebyshevtransform([1,2,3]) == chebyshevtransform([1.,2,3])
208258
@test chebyshevtransform([1,2,3], Val(2)) == chebyshevtransform([1.,2,3], Val(2))

0 commit comments

Comments
 (0)