Skip to content

Commit 6621387

Browse files
authored
safer inbounds in some helper functions (#429)
* safer inbounds in some helper functions * call sum in alternatingsum * Add tests * Add second negateeven test
1 parent e1384d7 commit 6621387

File tree

2 files changed

+47
-34
lines changed

2 files changed

+47
-34
lines changed

src/LinearAlgebra/helper.jl

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@ import Base: chop
88
@inline dot(M::Int,a::Ptr{T},incx::Int,b::Ptr{T},incy::Int) where {T<:Union{ComplexF64,ComplexF32}} =
99
BLAS.dotc(M,a,incx,b,incy)
1010

11-
dotu(f::StridedVector{T},g::StridedVector{T}) where {T<:Union{ComplexF32,ComplexF64}} =
11+
dotu(f::StridedVector{T}, g::StridedVector{T}) where {T<:Union{ComplexF32,ComplexF64}} =
1212
BLAS.dotu(f,g)
13-
dotu(f::AbstractVector{Complex{Float64}},g::AbstractVector{N}) where {N<:Real} = dot(conj(f),g)
14-
dotu(f::AbstractVector{N},g::AbstractVector{T}) where {N<:Real,T<:Number} = dot(f,g)
13+
dotu(f::AbstractVector{<:Complex}, g::AbstractVector{<:Real}) = dot(conj(f),g)
14+
dotu(f::AbstractVector{<:Real}, g::AbstractVector{<:Real}) = dot(f,g)
15+
function dotu(f::AbstractVector{<:Number}, g::AbstractVector{<:Number})
16+
Base.require_one_based_indexing(f)
17+
axes(f) == axes(g) || throw(ArgumentError("vectors must have the same indices"))
18+
mapreduce(*, +, f, g)
19+
end
1520

1621

1722
normalize!(w::AbstractVector) = rmul!(w,inv(norm(w)))
@@ -100,8 +105,8 @@ end
100105
scal!(n::Integer,cst::BlasFloat,ret::DenseArray{T},k::Integer) where {T<:BlasFloat} =
101106
BLAS.scal!(n,strictconvert(T,cst),ret,k)
102107

103-
function scal!(n::Integer,cst::Number,ret::AbstractArray,k::Integer)
104-
@assert k*n length(ret)
108+
@inline function scal!(n::Integer,cst::Number,ret::AbstractArray,k::Integer)
109+
@boundscheck checkbounds(ret, 1:(k*(n-1)+1))
105110
@simd for j=1:k:k*(n-1)+1
106111
@inbounds ret[j] *= cst
107112
end
@@ -115,6 +120,7 @@ scal!(cst::Number,v::AbstractArray) = scal!(length(v),cst,v,1)
115120
# Helper routines
116121

117122
function reverseeven!(x::AbstractVector)
123+
Base.require_one_based_indexing(x)
118124
n = length(x)
119125
if iseven(n)
120126
@inbounds @simd for k=2:2:n÷2
@@ -129,14 +135,15 @@ function reverseeven!(x::AbstractVector)
129135
end
130136

131137
function negateeven!(x::AbstractVector)
132-
@inbounds @simd for k = 2:2:length(x)
133-
x[k] *= -1
134-
end
138+
Base.require_one_based_indexing(x)
139+
v = view(x, 2:2:length(x))
140+
v .*= -1
135141
x
136142
end
137143

138144
#checkerboard, same as applying negativeeven! to all rows then all columns
139145
function negateeven!(X::AbstractMatrix)
146+
Base.require_one_based_indexing(X)
140147
for j = 1:2:size(X,2)
141148
@inbounds @simd for k = 2:2:size(X,1)
142149
X[k,j] *= -1
@@ -155,21 +162,15 @@ const alternatesign! = negateeven!
155162
alternatesign(v::AbstractVector) = alternatesign!(copy(v))
156163

157164
function alternatingsum(v::AbstractVector)
158-
ret = zero(eltype(v))
159-
s = 1
160-
@inbounds for k=1:length(v)
161-
ret+=s*v[k]
162-
s*=-1
163-
end
164-
165-
ret
165+
sum(((a,b),) -> a*b, zip(v, Iterators.cycle((1,-1))))
166166
end
167167

168168
# Sum Hadamard product of vectors up to minimum over lengths
169169
function mindotu(a::AbstractVector,b::AbstractVector)
170-
ret,m = zero(promote_type(eltype(a),eltype(b))),min(length(a),length(b))
171-
@inbounds @simd for i=m:-1:1 ret += a[i]*b[i] end
172-
ret
170+
Base.require_one_based_indexing(a)
171+
Base.require_one_based_indexing(b)
172+
m = min(length(a), length(b))
173+
dotu(view(a, 1:m), view(b, 1:m))
173174
end
174175

175176

@@ -240,25 +241,20 @@ function pad(A::AbstractMatrix,n::Integer,m::Integer)
240241
Base.require_one_based_indexing(A)
241242
T=eltype(A)
242243
if n <= size(A,1) && m <= size(A,2)
243-
A[1:n,1:m]
244-
elseif n==0 || m==0
245-
Matrix{T}(undef,n,m) #fixes weird julia bug when T==None
244+
strictconvert(Matrix{T}, A[1:n,1:m])
246245
else
247246
ret = Matrix{T}(undef,n,m)
248247
minn=min(n,size(A,1))
249248
minm=min(m,size(A,2))
250-
for k=1:minn,j=1:minm
251-
@inbounds ret[k,j]=A[k,j]
252-
end
253-
for k=minn+1:n,j=1:minm
254-
@inbounds ret[k,j]=zero(T)
255-
end
256-
for k=1:n,j=minm+1:m
257-
@inbounds ret[k,j]=zero(T)
258-
end
259-
for k=minn+1:n,j=minm+1:m
260-
@inbounds ret[k,j]=zero(T)
261-
end
249+
250+
cinds = CartesianIndices((1:minn, 1:minm))
251+
copyto!(ret, cinds, A, cinds)
252+
253+
cinds = CartesianIndices((minn+1:n, 1:minm))
254+
ret[cinds] .= zero(T)
255+
256+
cinds = CartesianIndices((axes(ret,1), minm+1:m))
257+
ret[cinds] .= zero(T)
262258

263259
ret
264260
end

test/runtests.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,23 @@ end
108108

109109
@test @inferred(ApproxFunBase.flipsign(2, 0im)) == 2
110110

111+
@testset "mindotu" begin
112+
@test ApproxFunBase.mindotu(Float64[1,2], Float64[1,2,3]) == sum([1,2] .* [1,2])
113+
@test ApproxFunBase.mindotu([1,2], [1,2,3]) == sum([1,2] .* [1,2])
114+
@test ApproxFunBase.mindotu(ComplexF64[1+im,2+4im], Float64[1,2,3]) == sum([1+im,2+4im] .* [1,2])
115+
@test ApproxFunBase.mindotu(Float64[1,2,3], ComplexF64[1+im,2+4im]) == sum([1+im,2+4im] .* [1,2])
116+
@test ApproxFunBase.mindotu(ComplexF64[1+2im,2+5im,3+2im], ComplexF64[1+im,2+4im]) ==
117+
sum(ComplexF64[1+2im,2+5im] .* ComplexF64[1+im,2+4im])
118+
end
119+
120+
@testset "negateeven!" begin
121+
v = [1,2,3,4,5]
122+
ApproxFunBase.negateeven!(v)
123+
@test v == [1,-2,3,-4,5]
124+
ApproxFunBase.negateeven!(v)
125+
@test v == [1,2,3,4,5]
126+
end
127+
111128
# TODO: Tensorizer tests
112129
end
113130

0 commit comments

Comments
 (0)