Skip to content

Commit 6afdbd2

Browse files
authored
Fix array nested sparse cases (#424)
* update the array nested sparse case * Update build_function.jl
1 parent 527765d commit 6afdbd2

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/build_function.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,14 @@ function _build_function(target::JuliaTarget, rhss, args...;
171171
_rhss = rhss
172172
end
173173

174-
if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
175-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
176-
elseif eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of sparse matrices
177-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2.nzval)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(_rhss)])
174+
if eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of sparse matrices
175+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2.nzval)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
176+
elseif eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
177+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)], init=Expr[])
178178
elseif eltype(rhss) <: SparseMatrixCSC # Array of sparse matrices
179-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel.nzval)]) for (i,rhsel) enumerate(_rhss)])
179+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel.nzval)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])
180180
elseif eltype(rhss) <: AbstractArray # Array of arrays
181-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init = Expr[])
181+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])
182182
elseif rhss isa SparseMatrixCSC
183183
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
184184
else

test/build_function_arrayofarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ h_sparse_arrayvec_ip!(out_2_arrayvec, input)
147147
@test out_1_arrayvec == out_2_arrayvec
148148

149149
# Arrays of Arrays of Matrices
150-
h_sparse_arrayNestedMat = [[[a 1; b 0], [0 0; 0 0]], [[b 1; a 0], [b c; 0 1]]]
150+
h_sparse_arrayNestedMat = [sparse.([[a 1; b 0], [0 0; 0 0]]), sparse.([[b 1; a 0], [b c; 0 1]])]
151151
function h_sparse_arrayNestedMat_julia(x)
152152
a, b, c = x
153153
return [sparse.([[a[1] 1; b[1] 0], [0 0; 0 0]]), sparse.([[b[1] 1; a[1] 0], [b[1] c[1]; 0 1]])]

0 commit comments

Comments
 (0)