Skip to content

Commit 2044168

Browse files
authored
Merge pull request #23 from SymbolicML/additional-utils
Additional utilities for identity functions, `+`, and `-`
2 parents edca829 + 2412d55 commit 2044168

File tree

3 files changed

+112
-25
lines changed

3 files changed

+112
-25
lines changed

src/math.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,28 @@ Base.:*(l::Dimensions, r::Dimensions) = @map_dimensions(+, l, r)
22
Base.:*(l::Quantity, r::Quantity) = Quantity(l.value * r.value, l.dimensions * r.dimensions)
33
Base.:*(l::Quantity, r::Dimensions) = Quantity(l.value, l.dimensions * r)
44
Base.:*(l::Dimensions, r::Quantity) = Quantity(r.value, l * r.dimensions)
5-
Base.:*(l::Quantity, r::Number) = Quantity(l.value * r, l.dimensions)
6-
Base.:*(l::Number, r::Quantity) = Quantity(l * r.value, r.dimensions)
7-
Base.:*(l::Dimensions, r::Number) = Quantity(r, l)
8-
Base.:*(l::Number, r::Dimensions) = Quantity(l, r)
5+
Base.:*(l::Quantity, r) = Quantity(l.value * r, l.dimensions)
6+
Base.:*(l, r::Quantity) = Quantity(l * r.value, r.dimensions)
7+
Base.:*(l::Dimensions, r) = Quantity(r, l)
8+
Base.:*(l, r::Dimensions) = Quantity(l, r)
99

1010
Base.:/(l::Dimensions, r::Dimensions) = @map_dimensions(-, l, r)
1111
Base.:/(l::Quantity, r::Quantity) = Quantity(l.value / r.value, l.dimensions / r.dimensions)
1212
Base.:/(l::Quantity, r::Dimensions) = Quantity(l.value, l.dimensions / r)
1313
Base.:/(l::Dimensions, r::Quantity) = Quantity(inv(r.value), l / r.dimensions)
14-
Base.:/(l::Quantity, r::Number) = Quantity(l.value / r, l.dimensions)
15-
Base.:/(l::Number, r::Quantity) = l * inv(r)
16-
Base.:/(l::Dimensions, r::Number) = Quantity(inv(r), l)
17-
Base.:/(l::Number, r::Dimensions) = Quantity(l, inv(r))
14+
Base.:/(l::Quantity, r) = Quantity(l.value / r, l.dimensions)
15+
Base.:/(l, r::Quantity) = l * inv(r)
16+
Base.:/(l::Dimensions, r) = Quantity(inv(r), l)
17+
Base.:/(l, r::Dimensions) = Quantity(l, inv(r))
1818

1919
Base.:+(l::Quantity, r::Quantity) = dimension(l) == dimension(r) ? Quantity(l.value + r.value, l.dimensions) : throw(DimensionError(l, r))
20-
Base.:-(l::Quantity, r::Quantity) = dimension(l) == dimension(r) ? Quantity(l.value - r.value, l.dimensions) : throw(DimensionError(l, r))
20+
Base.:-(l::Quantity) = Quantity(-l.value, l.dimensions)
21+
Base.:-(l::Quantity, r::Quantity) = l + (-r)
22+
23+
Base.:+(l::Quantity, r) = dimension(l) == dimension(r) ? Quantity(l.value + r, l.dimensions) : throw(DimensionError(l, r))
24+
Base.:+(l, r::Quantity) = dimension(l) == dimension(r) ? Quantity(l + r.value, r.dimensions) : throw(DimensionError(l, r))
25+
Base.:-(l::Quantity, r) = l + (-r)
26+
Base.:-(l, r::Quantity) = l + (-r)
2127

2228
_pow(l::Dimensions, r) = @map_dimensions(Base.Fix1(*, r), l)
2329
_pow(l::Quantity{T}, r) where {T} = Quantity(l.value^r, _pow(l.dimensions, r))

src/utils.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ Base.iszero(q::Quantity) = iszero(q.value)
3737
Base.getindex(d::Dimensions, k::Symbol) = getfield(d, k)
3838
Base.:(==)(l::Dimensions, r::Dimensions) = @all_dimensions(==, l, r)
3939
Base.:(==)(l::Quantity, r::Quantity) = l.value == r.value && l.dimensions == r.dimensions
40+
Base.:(==)(l, r::Quantity) = ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
41+
Base.:(==)(l::Quantity, r) = ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
42+
Base.isless(l::Quantity, r::Quantity) = dimension(l) == dimension(r) ? isless(ustrip(l), ustrip(r)) : throw(DimensionError(l, r))
43+
Base.isless(l::Quantity, r) = dimension(l) == dimension(r) ? isless(ustrip(l), r) : throw(DimensionError(l, r))
44+
Base.isless(l, r::Quantity) = dimension(l) == dimension(r) ? isless(l, ustrip(r)) : throw(DimensionError(l, r))
4045
Base.isapprox(l::Quantity, r::Quantity; kws...) = isapprox(l.value, r.value; kws...) && l.dimensions == r.dimensions
4146
Base.length(::Dimensions) = 1
4247
Base.length(::Quantity) = 1
@@ -45,16 +50,26 @@ Base.iterate(::Dimensions, ::Nothing) = nothing
4550
Base.iterate(q::Quantity) = (q, nothing)
4651
Base.iterate(::Quantity, ::Nothing) = nothing
4752

48-
Base.zero(::Type{Quantity{T,R}}) where {T,R} = Quantity(zero(T), R)
53+
# Multiplicative identities:
4954
Base.one(::Type{Quantity{T,R}}) where {T,R} = Quantity(one(T), R)
50-
Base.one(::Type{Dimensions{R}}) where {R} = Dimensions{R}()
51-
52-
Base.zero(::Type{Quantity{T}}) where {T} = zero(Quantity{T,DEFAULT_DIM_TYPE})
5355
Base.one(::Type{Quantity{T}}) where {T} = one(Quantity{T,DEFAULT_DIM_TYPE})
54-
55-
Base.zero(::Type{Quantity}) = zero(Quantity{DEFAULT_VALUE_TYPE})
5656
Base.one(::Type{Quantity}) = one(Quantity{DEFAULT_VALUE_TYPE})
57+
Base.one(::Type{Dimensions{R}}) where {R} = Dimensions{R}()
5758
Base.one(::Type{Dimensions}) = one(Dimensions{DEFAULT_DIM_TYPE})
59+
Base.one(q::Quantity) = Quantity(one(ustrip(q)), one(dimension(q)))
60+
Base.one(d::Dimensions) = one(typeof(d))
61+
62+
# Additive identities:
63+
Base.zero(q::Quantity) = Quantity(zero(ustrip(q)), dimension(q))
64+
Base.zero(::Dimensions) = error("There is no such thing as an additive identity for a `Dimensions` object, as + is only defined for `Quantity`.")
65+
Base.zero(::Type{<:Quantity}) = error("Cannot create an additive identity for a `Quantity` type, as the dimensions are unknown. Please use `zero(::Quantity)` instead.")
66+
Base.zero(::Type{<:Dimensions}) = error("There is no such thing as an additive identity for a `Dimensions` type, as + is only defined for `Quantity`.")
67+
68+
# Dimensionful 1:
69+
Base.oneunit(q::Quantity) = Quantity(oneunit(ustrip(q)), dimension(q))
70+
Base.oneunit(::Dimensions) = error("There is no such thing as a dimensionful 1 for a `Dimensions` object, as + is only defined for `Quantity`.")
71+
Base.oneunit(::Type{<:Quantity}) = error("Cannot create a dimensionful 1 for a `Quantity` type without knowing the dimensions. Please use `oneunit(::Quantity)` instead.")
72+
Base.oneunit(::Type{<:Dimensions}) = error("There is no such thing as a dimensionful 1 for a `Dimensions` type, as + is only defined for `Quantity`.")
5873

5974
Base.show(io::IO, d::Dimensions) =
6075
let tmp_io = IOBuffer()
@@ -101,15 +116,17 @@ Base.convert(::Type{Dimensions{R}}, d::Dimensions) where {R} = Dimensions{R}(d)
101116
Remove the units from a quantity.
102117
"""
103118
ustrip(q::Quantity) = q.value
104-
ustrip(q::Number) = q
119+
ustrip(::Dimensions) = error("Cannot remove units from a `Dimensions` object.")
120+
ustrip(q) = q
105121

106122
"""
107123
dimension(q::Quantity)
108124
109125
Get the dimensions of a quantity, returning a `Dimensions` object.
110126
"""
111127
dimension(q::Quantity) = q.dimensions
112-
dimension(::Number) = Dimensions()
128+
dimension(d::Dimensions) = d
129+
dimension(_) = Dimensions()
113130

114131
"""
115132
ulength(q::Quantity)

test/unittests.jl

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,27 @@ using Test
9898
@test uluminosity(y) == R(0)
9999
@test uamount(y) == R(0)
100100
@test ustrip(y) T(0.2^2.1)
101+
102+
dimensionless = Quantity(one(T), R)
103+
y = T(2) + dimensionless
104+
@test ustrip(y) == T(3)
105+
@test dimension(y) == Dimensions(R)
106+
@test typeof(y) == Quantity{T,R}
107+
108+
y = T(2) - dimensionless
109+
@test ustrip(y) == T(1)
110+
@test dimension(y) == Dimensions(R)
111+
@test typeof(y) == Quantity{T,R}
112+
113+
y = dimensionless + T(2)
114+
@test ustrip(y) == T(3)
115+
y = dimensionless - T(2)
116+
@test ustrip(y) == T(-1)
117+
118+
@test_throws DimensionError Quantity(one(T), R, length=1) + 1.0
119+
@test_throws DimensionError Quantity(one(T), R, length=1) - 1.0
120+
@test_throws DimensionError 1.0 + Quantity(one(T), R, length=1)
121+
@test_throws DimensionError 1.0 - Quantity(one(T), R, length=1)
101122
end
102123

103124
x = Quantity(-1.2, length=2 // 5)
@@ -108,7 +129,12 @@ end
108129

109130
@testset "Fallbacks" begin
110131
@test ustrip(0.5) == 0.5
132+
@test ustrip(ones(32)) == ones(32)
111133
@test dimension(0.5) == Dimensions()
134+
@test dimension(ones(32)) == Dimensions()
135+
@test dimension(Dimensions()) === Dimensions()
136+
137+
@test_throws ErrorException ustrip(Dimensions())
112138
end
113139

114140
@testset "Arrays" begin
@@ -126,6 +152,14 @@ end
126152

127153
uX = X .* Quantity(2, length=2.5, luminosity=0.5)
128154
@test sum(X) == 0.5 * ustrip(sum(uX))
155+
156+
x = Quantity(ones(T, 32))
157+
@test ustrip(x + ones(T, 32))[32] == 2
158+
@test typeof(x + ones(T, 32)) <: Quantity{Vector{T}}
159+
@test typeof(x - ones(T, 32)) <: Quantity{Vector{T}}
160+
@test typeof(ones(T, 32) * Dimensions(length=1)) <: Quantity{Vector{T}}
161+
@test typeof(ones(T, 32) / Dimensions(length=1)) <: Quantity{Vector{T}}
162+
@test ones(T, 32) / Dimensions(length=1) == Quantity(ones(T, 32), length=-1)
129163
end
130164
end
131165

@@ -150,25 +184,55 @@ end
150184

151185
@test Dimensions{Int8}([0 for i=1:length(DIMENSION_NAMES)]...) == Dimensions{Int8}()
152186

153-
@test zero(Quantity{ComplexF64,Int8}) + Quantity(1) == Quantity(1.0+0.0im, length=Int8(0))
154-
@test one(Quantity{ComplexF64,Int8}) - Quantity(1) == Quantity(0.0+0.0im, length=Int8(0))
187+
@test zero(Quantity(0.0+0.0im)) + Quantity(1) == Quantity(1.0+0.0im, length=Int8(0))
188+
@test oneunit(Quantity(0.0+0.0im)) - Quantity(1) == Quantity(0.0+0.0im, length=Int8(0))
155189
@test typeof(one(Dimensions{Int16})) == Dimensions{Int16}
156190
@test one(Dimensions{Int16}) == Dimensions(mass=Int16(0))
157191

158-
@test zero(Quantity{ComplexF64}) == Quantity(0.0+0.0im)
192+
@test zero(Quantity(0.0im)) == Quantity(0.0+0.0im)
159193
@test one(Quantity{ComplexF64}) == Quantity(1.0+0.0im)
160194

161-
@test zero(Quantity) == Quantity(0.0)
162-
@test typeof(zero(Quantity)) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
163-
@test one(Quantity) - Quantity(1) == Quantity(0.0)
164-
@test typeof(one(Quantity)) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
165-
@test typeof(one(Dimensions)) == Dimensions{DEFAULT_DIM_TYPE}
195+
@test zero(Quantity(0.0)) == Quantity(0.0)
196+
@test typeof(zero(Quantity(0.0))) == Quantity{Float64,DEFAULT_DIM_TYPE}
197+
@test oneunit(Quantity(1.0)) - Quantity(1.0) == Quantity(0.0)
198+
@test typeof(one(Quantity(1.0))) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
166199
@test one(Dimensions) == Dimensions()
200+
@test one(Dimensions()) == Dimensions()
201+
@test typeof(one(Quantity)) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
202+
@test ustrip(one(Quantity)) === one(DEFAULT_VALUE_TYPE)
203+
@test typeof(one(Quantity(ones(32, 32)))) == Quantity{Matrix{Float64},DEFAULT_DIM_TYPE}
204+
@test dimension(one(Quantity(ones(32, 32), length=1))) == Dimensions()
205+
206+
x = Quantity(1, length=1)
207+
208+
@test zero(x) == Quantity(0, length=1)
209+
@test typeof(zero(x)) == Quantity{Int64,DEFAULT_DIM_TYPE}
210+
211+
# Invalid calls:
212+
@test_throws ErrorException zero(Quantity)
213+
@test_throws ErrorException zero(Dimensions())
214+
@test_throws ErrorException zero(Dimensions)
215+
@test_throws ErrorException oneunit(Quantity)
216+
@test_throws ErrorException oneunit(Dimensions())
217+
@test_throws ErrorException oneunit(Dimensions)
167218

168219
@test sqrt(z * -1) == Quantity(sqrt(52), length=1 // 2, mass=1)
169220
@test cbrt(z) == Quantity(cbrt(-52), length=1 // 3, mass=2 // 3)
170221

171222
@test 1.0 * (Dimensions(length=3)^2) == Quantity(1.0, length=6)
223+
224+
x = 0.9u"km/s"
225+
y = 0.3 * x
226+
@test x > y
227+
@test y < x
228+
229+
x = Quantity(1.0)
230+
231+
@test x == 1.0
232+
@test x >= 1.0
233+
@test x < 2.0
234+
235+
@test_throws DimensionError x < 1.0u"m"
172236
end
173237

174238
@testset "Manual construction" begin

0 commit comments

Comments
 (0)