Skip to content

Commit 89cf39a

Browse files
authored
Merge pull request #404 from caryan/iszero-for-operations
provide iszero for Operation to help sparse matrix addition and multiplication
2 parents c3d4cbe + 402f7cc commit 89cf39a

File tree

3 files changed

+40
-16
lines changed

3 files changed

+40
-16
lines changed

src/operations.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ Base.isequal(::Variable , ::Operation) = false
5454
Base.isequal(::Operation, ::Constant ) = false
5555
Base.isequal(::Constant , ::Operation) = false
5656

57+
# provide iszero for Operations to help sparse addition and multiplication
58+
# e.g. we want to tell the sparse library that iszero(zero(Operation) + zero(Operation)) == true
59+
Base.iszero(x::Operation) = (_x = simplify(x); _x isa Constant && iszero(_x.value))
60+
5761
Base.show(io::IO, O::Operation) = print(io, convert(Expr, O))
5862

5963
# For inv
@@ -73,12 +77,4 @@ Base.convert(::Type{Expr},x::Operation) = Expr(x)
7377
Base.promote_rule(::Type{<:Constant}, ::Type{<:Operation}) = Operation
7478
Base.promote_rule(::Type{<:Operation}, ::Type{<:Constant}) = Operation
7579

76-
# Fix Sparse MatMul
77-
Base.:*(A::SparseMatrixCSC{Operation,S}, x::StridedVector{Operation}) where {S} =
78-
(T = Operation; mul!(similar(x, T, A.m), A, x, true, false))
79-
Base.:*(A::SparseMatrixCSC{Tx,S}, x::StridedVector{Operation}) where {Tx,S} =
80-
(T = LinearAlgebra.promote_op(LinearAlgebra.matprod, Operation, Tx); mul!(similar(x, T, A.m), A, x, true, false))
81-
Base.:*(A::SparseMatrixCSC{Operation,S}, x::StridedVector{Tx}) where {Tx,S} =
82-
(T = LinearAlgebra.promote_op(LinearAlgebra.matprod, Operation, Tx); mul!(similar(x, T, A.m), A, x, true, false))
83-
8480
LinearAlgebra.lu(O::AbstractMatrix{<:Operation};kwargs...) = lu(O,Val(false);kwargs...)

src/variables.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ Get the value of a [`ModelingToolkit.Constant`](@ref).
8989
"""
9090
Base.get(c::Constant) = c.value
9191

92-
Base.iszero(ex::Expression) = isa(ex, Constant) && iszero(ex.value)
92+
Base.iszero(c::Constant) = iszero(c.value)
93+
9394
Base.isone(ex::Expression) = isa(ex, Constant) && isone(ex.value)
9495

9596
# Variables use isequal for equality since == is an Operation

test/operation_overloads.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,34 @@
1-
using ModelingToolkit
2-
using LinearAlgebra
3-
@variables a,b,c,d
4-
X = [a b;c d]
5-
det(X)
6-
lu(X)
7-
inv(X)
1+
using ModelingToolkit
2+
using LinearAlgebra
3+
using SparseArrays: sparse
4+
using Test
5+
6+
@variables a,b,c,d
7+
8+
# test some matrix operations don't throw errors
9+
X = [a b;c d]
10+
det(X)
11+
lu(X)
12+
inv(X)
13+
14+
# test operations with sparse arrays and Operations
15+
# note `isequal` instead of `==` because `==` would give another Operation
16+
17+
# test that we can create a sparse array of Operation
18+
Oarray = zeros(Operation, 2,2)
19+
Oarray[2,2] = a
20+
@test isequal(sparse(Oarray), sparse([2], [2], [a]))
21+
22+
# test Operation * sparse
23+
@test isequal(a * sparse([2], [2], [1]), sparse([2], [2], [a * 1]))
24+
25+
# test sparse{Operation} + sparse{Operation}
26+
A = sparse([2], [2], [a])
27+
B = sparse([2], [2], [b])
28+
@test isequal(A + B, sparse([2], [2], [a+b]))
29+
30+
# test sparse{Operation} * sparse{Operation}
31+
C = sparse([1, 2], [2, 1], [c, c])
32+
D = sparse([1, 2], [2, 1], [d, d])
33+
34+
@test isequal(C * D, sparse([1,2], [1,2], [c * d, c * d]))

0 commit comments

Comments
 (0)