Skip to content

Commit aea91f4

Browse files
authored
Fix MTK.build_function on "arrays of sparse matrices" and "arrays of arrays of sparse matrices" (#423)
* Update build_function.jl * update related test
1 parent 098cef4 commit aea91f4

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/build_function.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,10 @@ function _build_function(target::JuliaTarget, rhss, args...;
173173

174174
if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
175175
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
176-
elseif eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of arrays
177-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(_rhss)])
176+
elseif eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of sparse matrices
177+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2.nzval)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(_rhss)])
178178
elseif eltype(rhss) <: SparseMatrixCSC # Array of sparse matrices
179-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)])
179+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel.nzval)]) for (i,rhsel) enumerate(_rhss)])
180180
elseif eltype(rhss) <: AbstractArray # Array of arrays
181181
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init = Expr[])
182182
elseif rhss isa SparseMatrixCSC

test/build_function_arrayofarray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ julia_sparse_arraymat = h_sparse_arraymat_julia(input)
114114
mtk_sparse_arraymat = h_sparse_arraymat_oop(input)
115115
@test_broken julia_sparse_arraymat == mtk_sparse_arraymat
116116
h_sparse_arraymat_julia!(out_1_arraymat, input)
117-
@test_broken h_sparse_arraymat_ip!(out_2_arraymat, input)
118-
@test_broken out_1_arraymat == out_2_arraymat
117+
h_sparse_arraymat_ip!(out_2_arraymat, input)
118+
@test out_1_arraymat == out_2_arraymat
119119

120120
# Array of 1D Vectors
121121
h_sparse_arrayvec = sparse.([[a, 0, c], [0, 0, 0], [1, a, b]])

0 commit comments

Comments
 (0)