Skip to content

Commit bfc4e71

Browse files
allow views in *, \ and contiguous lmul! and ldiv!
finish 3ba4ee3
1 parent 8a82d79 commit bfc4e71

File tree

2 files changed

+83
-35
lines changed

2 files changed

+83
-35
lines changed

examples/nonlocaldiffusion.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function evaluate_lambda(n::Integer, alpha::T, delta::T) where T
7070

7171
p = plan_jac2jac(T, n-1, zero(T), zero(T), alpha, zero(T))
7272

73-
lambda[2:end] .= p'lambda[2:end]
73+
lmul!(p', view(lambda, 2:n))
7474

7575
for i = 2:n-1
7676
lambda[i+1] = ((2i-1)*lambda[i+1] + (i-1)*lambda[i])/i

src/libfasttransforms.jl

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,16 @@ struct mpfr_t <: AbstractFloat
3333
end
3434

3535
"""
36-
`BigFloat` is a mutable struct and there is no guarantee that each entry in
37-
an `Array{BigFloat}` has unique pointers. For example, looking at the `Limb`s,
36+
`BigFloat` is a mutable struct and there is no guarantee that each entry in an
37+
`AbstractArray{BigFloat}` refers to its own distinct data; several entries may alias the same underlying storage. For example, looking at the `Limb`s,
3838
3939
Id = Matrix{BigFloat}(I, 3, 3)
4040
map(x->x.d, Id)
4141
4242
shows that the ones and the zeros all share the same pointers. If a C function
4343
assumes that each datum has its own, unshared storage, then the array must be renewed with a `deepcopy`.
4444
"""
45-
function renew!(x::Array{BigFloat})
45+
function renew!(x::AbstractArray{BigFloat})
4646
for i in eachindex(x)
4747
@inbounds x[i] = deepcopy(x[i])
4848
end
@@ -221,12 +221,20 @@ show(io::IO, p::FTPlan{T, 2, K}) where {T, K} = print(io, "FastTransforms plan f
221221
show(io::IO, p::FTPlan{T, 3, K}) where {T, K} = print(io, "FastTransforms plan for ", kind2string(K), " for $(p.n)×$(p.l)×$(p.m)-element array of ", T)
222222
show(io::IO, p::FTPlan{T, 2, SPHERICALISOMETRY}) where T = print(io, "FastTransforms ", kind2string(SPHERICALISOMETRY), " plan for $(p.n)×$(2p.n-1)-element array of ", T)
223223

224-
function checksize(p::FTPlan{T, 1}, x::Array{T}) where T
224+
function checksize(p::FTPlan{T, 1}, x::StridedArray{T}) where T
225225
if p.n != size(x, 1)
226226
throw(DimensionMismatch("FTPlan has dimensions $(p.n) × $(p.n), x has leading dimension $(size(x, 1))"))
227227
end
228228
end
229229

230+
function checkstrides(p::FTPlan{T, 1}, x::StridedArray{T}) where T
231+
sz = size(x)
232+
st = strides(x)
233+
if (1, cumprod(sz)...) != (st..., length(x))
234+
error("FTPlan requires unit strides, x has strides $(strides(x))")
235+
end
236+
end
237+
230238
for (N, K) in ((2, RECTDISK), (2, TRIANGLE), (3, TETRAHEDRON))
231239
@eval function checksize(p::FTPlan{T, $N, $K}, x::Array{T, $N}) where T
232240
if p.n != size(x, 1)
@@ -325,6 +333,14 @@ function checksize(p::AdjointFTPlan, x)
325333
end
326334
end
327335

336+
function checkstrides(p::AdjointFTPlan, x)
337+
try
338+
checkstrides(p.adjoint, x)
339+
catch
340+
checkstrides(p.parent, x)
341+
end
342+
end
343+
328344
function unsafe_convert(::Type{Ptr{ft_plan_struct}}, p::AdjointFTPlan)
329345
try
330346
unsafe_convert(Ptr{ft_plan_struct}, p.adjoint)
@@ -373,6 +389,14 @@ function checksize(p::TransposeFTPlan, x)
373389
end
374390
end
375391

392+
function checkstrides(p::TransposeFTPlan, x)
393+
try
394+
checkstrides(p.transpose, x)
395+
catch
396+
checkstrides(p.parent, x)
397+
end
398+
end
399+
376400
function unsafe_convert(::Type{Ptr{ft_plan_struct}}, p::TransposeFTPlan)
377401
try
378402
unsafe_convert(Ptr{ft_plan_struct}, p.transpose)
@@ -803,18 +827,21 @@ for (fJ, fC, elty) in ((:lmul!, :ft_bfmvf, :Float32),
803827
(:lmul!, :ft_bfmv , :Float64),
804828
(:ldiv!, :ft_bfsv , :Float64))
805829
@eval begin
806-
function $fJ(p::FTPlan{$elty, 1}, x::Vector{$elty})
830+
function $fJ(p::FTPlan{$elty, 1}, x::StridedVector{$elty})
807831
checksize(p, x)
832+
checkstrides(p, x)
808833
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'N', p, x)
809834
return x
810835
end
811-
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, K}}, x::Vector{$elty}) where K
836+
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, K}}, x::StridedVector{$elty}) where K
812837
checksize(p, x)
838+
checkstrides(p, x)
813839
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'T', p, x)
814840
return x
815841
end
816-
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, K}}, x::Vector{$elty}) where K
842+
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, K}}, x::StridedVector{$elty}) where K
817843
checksize(p, x)
844+
checkstrides(p, x)
818845
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'T', p, x)
819846
return x
820847
end
@@ -824,18 +851,21 @@ end
824851
for (fJ, fC, elty) in ((:lmul!, :ft_bbbfmvf, :Float32),
825852
(:lmul!, :ft_bbbfmv , :Float64))
826853
@eval begin
827-
function $fJ(p::FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}, x::Vector{$elty})
854+
function $fJ(p::FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}, x::StridedVector{$elty})
828855
checksize(p, x)
856+
checkstrides(p, x)
829857
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'N', '2', '1', p, x)
830858
return x
831859
end
832-
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::Vector{$elty})
860+
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::StridedVector{$elty})
833861
checksize(p, x)
862+
checkstrides(p, x)
834863
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'T', '1', '2', p, x)
835864
return x
836865
end
837-
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::Vector{$elty})
866+
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::StridedVector{$elty})
838867
checksize(p, x)
868+
checkstrides(p, x)
839869
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'T', '1', '2', p, x)
840870
return x
841871
end
@@ -845,18 +875,21 @@ end
845875
for (fJ, fC, elty) in ((:lmul!, :ft_mpmv, :Float64),
846876
(:ldiv!, :ft_mpsv, :Float64))
847877
@eval begin
848-
function $fJ(p::ModifiedFTPlan{$elty}, x::Vector{$elty})
878+
function $fJ(p::ModifiedFTPlan{$elty}, x::StridedVector{$elty})
849879
checksize(p, x)
880+
checkstrides(p, x)
850881
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'N', p, x)
851882
return x
852883
end
853-
function $fJ(p::AdjointFTPlan{$elty, ModifiedFTPlan{$elty}}, x::Vector{$elty})
884+
function $fJ(p::AdjointFTPlan{$elty, ModifiedFTPlan{$elty}}, x::StridedVector{$elty})
854885
checksize(p, x)
886+
checkstrides(p, x)
855887
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'T', p, x)
856888
return x
857889
end
858-
function $fJ(p::TransposeFTPlan{$elty, ModifiedFTPlan{$elty}}, x::Vector{$elty})
890+
function $fJ(p::TransposeFTPlan{$elty, ModifiedFTPlan{$elty}}, x::StridedVector{$elty})
859891
checksize(p, x)
892+
checkstrides(p, x)
860893
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}), 'T', p, x)
861894
return x
862895
end
@@ -866,18 +899,21 @@ end
866899
for (fJ, fC) in ((:lmul!, :ft_mpfr_trmv_ptr),
867900
(:ldiv!, :ft_mpfr_trsv_ptr))
868901
@eval begin
869-
function $fJ(p::FTPlan{BigFloat, 1}, x::Vector{BigFloat})
902+
function $fJ(p::FTPlan{BigFloat, 1}, x::StridedVector{BigFloat})
870903
checksize(p, x)
904+
checkstrides(p, x)
871905
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Ptr{mpfr_t}, Cint, Ptr{BigFloat}, Int32), 'N', p.n, p, p.n, renew!(x), Base.MPFR.ROUNDING_MODE[])
872906
return x
873907
end
874-
function $fJ(p::AdjointFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::Vector{BigFloat}) where K
908+
function $fJ(p::AdjointFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::StridedVector{BigFloat}) where K
875909
checksize(p, x)
910+
checkstrides(p, x)
876911
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Ptr{mpfr_t}, Cint, Ptr{BigFloat}, Int32), 'T', p.parent.n, p, p.parent.n, renew!(x), Base.MPFR.ROUNDING_MODE[])
877912
return x
878913
end
879-
function $fJ(p::TransposeFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::Vector{BigFloat}) where K
914+
function $fJ(p::TransposeFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::StridedVector{BigFloat}) where K
880915
checksize(p, x)
916+
checkstrides(p, x)
881917
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Ptr{mpfr_t}, Cint, Ptr{BigFloat}, Int32), 'T', p.parent.n, p, p.parent.n, renew!(x), Base.MPFR.ROUNDING_MODE[])
882918
return x
883919
end
@@ -889,18 +925,21 @@ for (fJ, fC, elty) in ((:lmul!, :ft_bfmmf, :Float32),
889925
(:lmul!, :ft_bfmm , :Float64),
890926
(:ldiv!, :ft_bfsm , :Float64))
891927
@eval begin
892-
function $fJ(p::FTPlan{$elty, 1}, x::Matrix{$elty})
928+
function $fJ(p::FTPlan{$elty, 1}, x::StridedMatrix{$elty})
893929
checksize(p, x)
930+
checkstrides(p, x)
894931
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'N', p, x, size(x, 1), size(x, 2))
895932
return x
896933
end
897-
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, K}}, x::Matrix{$elty}) where K
934+
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, K}}, x::StridedMatrix{$elty}) where K
898935
checksize(p, x)
936+
checkstrides(p, x)
899937
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'T', p, x, size(x, 1), size(x, 2))
900938
return x
901939
end
902-
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, K}}, x::Matrix{$elty}) where K
940+
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, K}}, x::StridedMatrix{$elty}) where K
903941
checksize(p, x)
942+
checkstrides(p, x)
904943
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'T', p, x, size(x, 1), size(x, 2))
905944
return x
906945
end
@@ -910,18 +949,21 @@ end
910949
for (fJ, fC, elty) in ((:lmul!, :ft_bbbfmmf, :Float32),
911950
(:lmul!, :ft_bbbfmm , :Float64))
912951
@eval begin
913-
function $fJ(p::FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}, x::Matrix{$elty})
952+
function $fJ(p::FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}, x::StridedMatrix{$elty})
914953
checksize(p, x)
954+
checkstrides(p, x)
915955
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'N', '2', '1', p, x, size(x, 1), size(x, 2))
916956
return x
917957
end
918-
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::Matrix{$elty})
958+
function $fJ(p::AdjointFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::StridedMatrix{$elty})
919959
checksize(p, x)
960+
checkstrides(p, x)
920961
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'T', '1', '2', p, x, size(x, 1), size(x, 2))
921962
return x
922963
end
923-
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::Matrix{$elty})
964+
function $fJ(p::TransposeFTPlan{$elty, FTPlan{$elty, 1, ASSOCIATEDJAC2JAC}}, x::StridedMatrix{$elty})
924965
checksize(p, x)
966+
checkstrides(p, x)
925967
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'T', '1', '2', p, x, size(x, 1), size(x, 2))
926968
return x
927969
end
@@ -931,18 +973,21 @@ end
931973
for (fJ, fC, elty) in ((:lmul!, :ft_mpmm, :Float64),
932974
(:ldiv!, :ft_mpsm, :Float64))
933975
@eval begin
934-
function $fJ(p::ModifiedFTPlan{$elty}, x::Matrix{$elty})
976+
function $fJ(p::ModifiedFTPlan{$elty}, x::StridedMatrix{$elty})
935977
checksize(p, x)
978+
checkstrides(p, x)
936979
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'N', p, x, size(x, 1), size(x, 2))
937980
return x
938981
end
939-
function $fJ(p::AdjointFTPlan{$elty, ModifiedFTPlan{$elty}}, x::Matrix{$elty})
982+
function $fJ(p::AdjointFTPlan{$elty, ModifiedFTPlan{$elty}}, x::StridedMatrix{$elty})
940983
checksize(p, x)
984+
checkstrides(p, x)
941985
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'T', p, x, size(x, 1), size(x, 2))
942986
return x
943987
end
944-
function $fJ(p::TransposeFTPlan{$elty, ModifiedFTPlan{$elty}}, x::Matrix{$elty})
988+
function $fJ(p::TransposeFTPlan{$elty, ModifiedFTPlan{$elty}}, x::StridedMatrix{$elty})
945989
checksize(p, x)
990+
checkstrides(p, x)
946991
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Ptr{ft_plan_struct}, Ptr{$elty}, Cint, Cint), 'T', p, x, size(x, 1), size(x, 2))
947992
return x
948993
end
@@ -952,18 +997,21 @@ end
952997
for (fJ, fC) in ((:lmul!, :ft_mpfr_trmm_ptr),
953998
(:ldiv!, :ft_mpfr_trsm_ptr))
954999
@eval begin
955-
function $fJ(p::FTPlan{BigFloat, 1}, x::Matrix{BigFloat})
1000+
function $fJ(p::FTPlan{BigFloat, 1}, x::StridedMatrix{BigFloat})
9561001
checksize(p, x)
1002+
checkstrides(p, x)
9571003
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Ptr{mpfr_t}, Cint, Ptr{BigFloat}, Cint, Cint, Int32), 'N', p.n, p, p.n, renew!(x), size(x, 1), size(x, 2), Base.MPFR.ROUNDING_MODE[])
9581004
return x
9591005
end
960-
function $fJ(p::AdjointFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::Matrix{BigFloat}) where K
1006+
function $fJ(p::AdjointFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::StridedMatrix{BigFloat}) where K
9611007
checksize(p, x)
1008+
checkstrides(p, x)
9621009
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Ptr{mpfr_t}, Cint, Ptr{BigFloat}, Cint, Cint, Int32), 'T', p.parent.n, p, p.parent.n, renew!(x), size(x, 1), size(x, 2), Base.MPFR.ROUNDING_MODE[])
9631010
return x
9641011
end
965-
function $fJ(p::TransposeFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::Matrix{BigFloat}) where K
1012+
function $fJ(p::TransposeFTPlan{BigFloat, FTPlan{BigFloat, 1, K}}, x::StridedMatrix{BigFloat}) where K
9661013
checksize(p, x)
1014+
checkstrides(p, x)
9671015
ccall(($(string(fC)), libfasttransforms), Cvoid, (Cint, Cint, Ptr{mpfr_t}, Cint, Ptr{BigFloat}, Cint, Cint, Int32), 'T', p.parent.n, p, p.parent.n, renew!(x), size(x, 1), size(x, 2), Base.MPFR.ROUNDING_MODE[])
9681016
return x
9691017
end
@@ -1077,12 +1125,12 @@ function execute_sph_reflection!(p::FTPlan{Float64, 2, SPHERICALISOMETRY}, w, x:
10771125
end
10781126
execute_sph_reflection!(p::FTPlan{Float64, 2, SPHERICALISOMETRY}, w1, w2, w3, x::Matrix{Float64}) = execute_sph_reflection!(p, ft_reflection(w1, w2, w3), x)
10791127

1080-
*(p::FTPlan{T}, x::Array{Complex{T}}) where T = lmul!(p, deepcopy(x))
1081-
*(p::AdjointFTPlan{T}, x::Array{Complex{T}}) where T = lmul!(p, deepcopy(x))
1082-
*(p::TransposeFTPlan{T}, x::Array{Complex{T}}) where T = lmul!(p, deepcopy(x))
1083-
\(p::FTPlan{T}, x::Array{Complex{T}}) where T = ldiv!(p, deepcopy(x))
1084-
\(p::AdjointFTPlan{T}, x::Array{Complex{T}}) where T = ldiv!(p, deepcopy(x))
1085-
\(p::TransposeFTPlan{T}, x::Array{Complex{T}}) where T = ldiv!(p, deepcopy(x))
1128+
*(p::FTPlan{T}, x::AbstractArray{Complex{T}}) where T = lmul!(p, Array(x))
1129+
*(p::AdjointFTPlan{T}, x::AbstractArray{Complex{T}}) where T = lmul!(p, Array(x))
1130+
*(p::TransposeFTPlan{T}, x::AbstractArray{Complex{T}}) where T = lmul!(p, Array(x))
1131+
\(p::FTPlan{T}, x::AbstractArray{Complex{T}}) where T = ldiv!(p, Array(x))
1132+
\(p::AdjointFTPlan{T}, x::AbstractArray{Complex{T}}) where T = ldiv!(p, Array(x))
1133+
\(p::TransposeFTPlan{T}, x::AbstractArray{Complex{T}}) where T = ldiv!(p, Array(x))
10861134

10871135
for fJ in (:lmul!, :ldiv!)
10881136
@eval begin

0 commit comments

Comments
 (0)