Skip to content

Commit 367adf9

Browse files
authored
safer inbounds in RaggedMatrix (#431)
* safer inbounds in RaggedMatrix * fix RaggedMatrix constructor colstop checks * propagate_inbounds in get and setindex
1 parent cb25a5f commit 367adf9

File tree

1 file changed

+47
-43
lines changed

1 file changed

+47
-43
lines changed

src/LinearAlgebra/RaggedMatrix.jl

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# FiniteRange gives the nonzero entries in a row/column
22
struct FiniteRange end
33

4-
getindex(A::AbstractMatrix,::Type{FiniteRange},j::Integer) = A[1:colstop(A,j),j]
4+
getindex(A::AbstractMatrix,::Type{FiniteRange},j::Integer) = A[colrange(A,j),j]
55
getindex(A::AbstractMatrix,k::Integer,::Type{FiniteRange}) = A[k,1:rowstop(A,k)]
66

77
const = FiniteRange
@@ -32,40 +32,48 @@ RaggedMatrix{T}(::UndefInitializer, m::Int, colns::AbstractVector{Int}) where {T
3232
RaggedMatrix(Vector{T}(undef, sum(colns)),Int[1;1 .+ cumsum(colns)],m)
3333

3434

35-
Base.size(A::RaggedMatrix) = (A.m,length(A.cols)-1)
35+
size(A::RaggedMatrix) = (A.m,length(A.cols)-1)
3636

3737
colstart(A::RaggedMatrix,j::Integer) = 1
3838
colstop(A::RaggedMatrix,j::Integer) = min(A.cols[j+1]-A.cols[j],size(A,1))
3939

40-
@inline function incol(A, k, j, ind = A.cols[j]+k-1)
40+
Base.@propagate_inbounds function incol(A, k, j, ind = A.cols[j]+k-1)
4141
ind < A.cols[j+1]
4242
end
4343

44-
function getindex(A::RaggedMatrix,k::Int,j::Int)
45-
if k>size(A,1) || k < 1 || j>size(A,2) || j < 1
44+
Base.@propagate_inbounds function incols_getindex(A::RaggedMatrix, k::Int, j::Int, ind = A.cols[j]+k-1)
45+
A.data[ind]
46+
end
47+
Base.@propagate_inbounds function incols_setindex!(A::RaggedMatrix, v, k::Int, j::Int, ind = A.cols[j]+k-1)
48+
A.data[ind] = v
49+
A
50+
end
51+
52+
Base.@propagate_inbounds function getindex(A::RaggedMatrix,k::Int,j::Int)
53+
@boundscheck if k>size(A,1) || k < 1 || j>size(A,2) || j < 1
4654
throw(BoundsError(A,(k,j)))
4755
end
4856

4957
ind = A.cols[j]+k-1
5058
if incol(A, k, j, ind)
51-
A.data[ind]
59+
incols_getindex(A, k, j, ind)
5260
else
5361
zero(eltype(A))
5462
end
5563
end
5664

57-
function Base.setindex!(A::RaggedMatrix,v,k::Int,j::Int)
58-
if k>size(A,1) || k < 1 || j>size(A,2) || j < 1
65+
Base.@propagate_inbounds function setindex!(A::RaggedMatrix,v,k::Int,j::Int)
66+
@boundscheck if k>size(A,1) || k < 1 || j>size(A,2) || j < 1
5967
throw(BoundsError(A,(k,j)))
6068
end
6169

6270
ind = A.cols[j]+k-1
6371
if incol(A, k, j, ind)
64-
A.data[ind]=v
72+
incols_setindex!(A, v, k, j, ind)
6573
elseif v 0
66-
throw(BoundsError(A,(k,j)))
74+
throw(ArgumentError("Can't set index $((k,j)) of a RaggedMatrix to a non-zero value"))
6775
end
68-
v
76+
A
6977
end
7078

7179
convert(::Type{RaggedMatrix{T}}, R::RaggedMatrix{T}) where T = R
@@ -75,40 +83,32 @@ convert(::Type{RaggedMatrix{T}}, R::RaggedMatrix) where T =
7583

7684

7785
function convert(::Type{Matrix{T}}, A::RaggedMatrix) where T
78-
ret = zeros(T,size(A,1),size(A,2))
79-
for j=1:size(A,2)
80-
ret[1:colstop(A,j),j] = view(A,1:colstop(A,j),j)
86+
ret = zeros(T, size(A))
87+
@inbounds for j in axes(A,2), k in colrange(A,j)
88+
v = incols_getindex(A, k, j)
89+
ret[k,j] = v
8190
end
8291
ret
8392
end
8493

8594
convert(::Type{Matrix}, A::RaggedMatrix) = Matrix{eltype(A)}(A)
8695

87-
function convert(::Type{RaggedMatrix{T}}, B::BandedMatrix) where T
88-
l = bandwidth(B,1)
89-
ret = RaggedMatrix(Zeros{T}(size(B)), Int[colstop(B,j) for j=1:size(B,2)])
90-
for j=1:size(B,2),k=colrange(B,j)
91-
ret[k,j] = B[k,j]
92-
end
93-
ret
94-
end
95-
9696
convert(::Type{RaggedMatrix}, B::BandedMatrix) = RaggedMatrix{eltype(B)}(B)
9797

9898
function convert(::Type{RaggedMatrix{T}}, B::AbstractMatrix) where T
99-
ret = RaggedMatrix(Zeros{T}(size(B)), Int[colstop(B,j) for j=1:size(B,2)])
100-
for j=1:size(B,2), k=colrange(B,j)
101-
ret[k,j] = B[k,j]
99+
ret = RaggedMatrix(Zeros{T}(size(B)), Int[colstop(B,j) for j=axes(B,2)])
100+
@inbounds for j in axes(B,2), k in colrange(B,j)
101+
incols_setindex!(ret, B[k,j], k, j)
102102
end
103103
ret
104104
end
105105

106-
convert(::Type{RaggedMatrix}, B::AbstractMatrix) = RaggedMatrix{eltype(B)}(B)
106+
convert(::Type{RaggedMatrix}, B::AbstractMatrix) = strictconvert(RaggedMatrix{eltype(B)}, B)
107107

108108
RaggedMatrix(B::AbstractMatrix) = strictconvert(RaggedMatrix, B)
109109
RaggedMatrix{T}(B::AbstractMatrix) where T = strictconvert(RaggedMatrix{T}, B)
110110

111-
Base.similar(B::RaggedMatrix,::Type{T}) where {T} = RaggedMatrix(similar(B.data, T),copy(B.cols),B.m)
111+
similar(B::RaggedMatrix,::Type{T}) where {T} = RaggedMatrix(similar(B.data, T),copy(B.cols),B.m)
112112

113113
for (op,bop) in ((:(Base.rand), :rrand),)
114114
@eval begin
@@ -126,9 +126,13 @@ function RaggedMatrix{T}(Z::Zeros, colns::AbstractVector{Int}) where {T}
126126
end
127127

128128
function RaggedMatrix{T}(A::AbstractMatrix, colns::AbstractVector{Int}) where T
129+
Base.require_one_based_indexing(A)
130+
Base.require_one_based_indexing(colns)
129131
ret = RaggedMatrix{T}(undef, size(A,1), colns)
130-
@inbounds for j = 1:length(colns), k = 1:colns[j]
131-
ret[k,j] = A[k,j]
132+
(length(colns) == size(A,2) && all(<=(size(A,1)), colns)) ||
133+
throw(ArgumentError("column stops $colns incompatible with input matrix of size $(size(A))"))
134+
for j in axes(A,2), k = 1:colns[j]
135+
@inbounds incols_setindex!(ret, A[k,j], k, j)
132136
end
133137
ret
134138
end
@@ -149,7 +153,7 @@ function mul!(y::Vector, A::RaggedMatrix, b::Vector)
149153
end
150154
T=eltype(y)
151155
fill!(y,zero(T))
152-
for j=1:m
156+
for j in axes(A,2)
153157
kr=A.cols[j]:A.cols[j+1]-1
154158
axpy!(b[j],view(A.data,kr),view(y,1:length(kr)))
155159
end
@@ -165,7 +169,7 @@ function axpy!(a, X::RaggedMatrix, Y::RaggedMatrix)
165169
if X.cols == Y.cols
166170
axpy!(a,X.data,Y.data)
167171
else
168-
for j = 1:size(X,2)
172+
for j = axes(X,2)
169173
Xn = colstop(X,j)
170174
Yn = colstop(Y,j)
171175
if Xn > Yn # check zeros otherwise
@@ -174,8 +178,8 @@ function axpy!(a, X::RaggedMatrix, Y::RaggedMatrix)
174178
end
175179
end
176180
cs = min(Xn,Yn)
177-
axpy!(a,view(X.data,X.cols[j]:X.cols[j]+cs-1),
178-
view(Y.data,Y.cols[j]:Y.cols[j]+cs-1))
181+
axpy!(a, view(X.data, range(X.cols[j], length=cs)),
182+
view(Y.data, range(Y.cols[j], length=cs)))
179183
end
180184
end
181185
Y
@@ -196,7 +200,7 @@ function axpy!(a,X::RaggedMatrix,
196200
ksh = first(parentindices(Y)[1]) - 1 # how much to shift
197201
jsh = first(parentindices(Y)[2]) - 1 # how much to shift
198202

199-
for j=1:size(X,2)
203+
for j=axes(X,2)
200204
cx=colstop(X,j)
201205
cy=colstop(Y,j)
202206
if cx > cy
@@ -205,7 +209,7 @@ function axpy!(a,X::RaggedMatrix,
205209
throw(BoundsError("Trying to add a non-zero to a zero."))
206210
end
207211
end
208-
kr = X.cols[j]:X.cols[j]+cy-1
212+
kr = range(X.cols[j], length=cy)
209213
else
210214
kr = X.cols[j]:X.cols[j+1]-1
211215
end
@@ -222,7 +226,7 @@ end
222226
function *(A::RaggedMatrix,B::RaggedMatrix)
223227
cols = zeros(Int,size(B,2))
224228
T = promote_type(eltype(A),eltype(B))
225-
for j=1:size(B,2),k=1:colstop(B,j)
229+
for j=axes(B,2),k=colrange(B,j)
226230
cols[j] = max(cols[j],colstop(A,k))
227231
end
228232

@@ -232,23 +236,23 @@ end
232236
function unsafe_mul!(Y::RaggedMatrix,A::RaggedMatrix,B::RaggedMatrix)
233237
fill!(Y.data,0)
234238

235-
for j=1:size(B,2),k=1:colstop(B,j)
236-
axpy!(B[k,j], view(A,1:colstop(A,k),k),
237-
view(Y.data,Y.cols[j] .- 1 .+ (1:colstop(A,k))))
239+
for j=axes(B,2),k=colrange(B,j)
240+
axpy!(B[k,j], view(A,colrange(A,k),k),
241+
view(Y.data,Y.cols[j] .- 1 .+ (colrange(A,k))))
238242
end
239243

240244
Y
241245
end
242246

243247
function mul!(Y::RaggedMatrix,A::RaggedMatrix,B::RaggedMatrix)
244-
for j=1:size(B,2)
248+
for j=axes(B,2)
245249
col = 0
246-
for k=1:colstop(B,j)
250+
for k=colrange(B,j)
247251
col = max(col,colstop(A,k))
248252
end
249253

250254
if col > colstop(Y,j)
251-
throw(BoundsError())
255+
throw(BoundsError(Y, (col,j)))
252256
end
253257
end
254258

0 commit comments

Comments
 (0)