Skip to content

Commit f8771e2

Browse files
committed
Ensure dimensions are always copied
1 parent 30ac1e8 commit f8771e2

File tree

3 files changed

+31
-30
lines changed

3 files changed

+31
-30
lines changed

src/arrays.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ end
104104
function Base.getindex(A::QuantityArray, i...)
105105
output_value = getindex(ustrip(A), i...)
106106
if isa(output_value, AbstractArray)
107-
return QuantityArray(output_value, dimension(A), quantity_type(A))
107+
return QuantityArray(output_value, copy(dimension(A)), quantity_type(A))
108108
else
109-
return new_quantity(quantity_type(A), output_value, dimension(A))
109+
return new_quantity(quantity_type(A), output_value, copy(dimension(A)))
110110
end
111111
end
112112
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::Q, i...) where {T,N,D,Q<:AbstractQuantity}
@@ -122,14 +122,14 @@ unsafe_setindex!(A, v, i...) = setindex!(ustrip(A), ustrip(v), i...)
122122
Base.IndexStyle(::Type{Q}) where {Q<:QuantityArray} = IndexStyle(array_type(Q))
123123

124124

125-
Base.similar(A::QuantityArray) = QuantityArray(similar(ustrip(A)), dimension(A), quantity_type(A))
126-
Base.similar(A::QuantityArray, ::Type{S}) where {S} = QuantityArray(similar(ustrip(A), S), dimension(A), quantity_type(A))
125+
Base.similar(A::QuantityArray) = QuantityArray(similar(ustrip(A)), copy(dimension(A)), quantity_type(A))
126+
Base.similar(A::QuantityArray, ::Type{S}) where {S} = QuantityArray(similar(ustrip(A), S), copy(dimension(A)), quantity_type(A))
127127

128128
# Unfortunately this mess of `similar` is required to avoid ambiguous methods.
129129
# c.f. base/abstractarray.jl
130130
for dim_type in (:(Dims), :(Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}), :(Tuple{Integer, Vararg{Integer}}))
131-
@eval Base.similar(A::QuantityArray, dims::$dim_type) = QuantityArray(similar(ustrip(A), dims), dimension(A), quantity_type(A))
132-
@eval Base.similar(A::QuantityArray, ::Type{S}, dims::$dim_type) where {S} = QuantityArray(similar(ustrip(A), S, dims), dimension(A), quantity_type(A))
131+
@eval Base.similar(A::QuantityArray, dims::$dim_type) = QuantityArray(similar(ustrip(A), dims), copy(dimension(A)), quantity_type(A))
132+
@eval Base.similar(A::QuantityArray, ::Type{S}, dims::$dim_type) where {S} = QuantityArray(similar(ustrip(A), S, dims), copy(dimension(A)), quantity_type(A))
133133
end
134134

135135
Base.BroadcastStyle(::Type{QA}) where {QA<:QuantityArray} = Broadcast.ArrayStyle{QA}()
@@ -138,7 +138,7 @@ function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Typ
138138
T = value_type(ElType)
139139
output_array = similar(bc, T)
140140
first_output::ElType = materialize_first(bc)
141-
return QuantityArray(output_array, dimension(first_output)::dim_type(ElType), ElType)
141+
return QuantityArray(output_array, copy(dimension(first_output))::dim_type(ElType), ElType)
142142
end
143143
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QuantityArray{T,N,D,Q,V}}}, ::Type{ElType}) where {T,N,D,Q,V<:Array{T,N},ElType}
144144
return similar(Array{ElType}, axes(bc))
@@ -187,7 +187,7 @@ for f in (:cat, :hcat, :vcat)
187187
preamble = quote
188188
allequal(dimension.(A)) || throw(DimensionError(A[begin], A[begin+1:end]))
189189
A = promote(A...)
190-
dimensions = dimension(A[begin])
190+
dimensions = copy(dimension(A[begin]))
191191
Q = quantity_type(A[begin])
192192
end
193193
if f == :cat
@@ -202,8 +202,8 @@ for f in (:cat, :hcat, :vcat)
202202
end
203203
end
204204
end
205-
Base.fill(x::AbstractQuantity, dims::Dims...) = QuantityArray(fill(ustrip(x), dims...), dimension(x), typeof(x))
206-
Base.fill(x::AbstractQuantity, t::Tuple{}) = QuantityArray(fill(ustrip(x), t), dimension(x), typeof(x))
205+
Base.fill(x::AbstractQuantity, dims::Dims...) = QuantityArray(fill(ustrip(x), dims...), copy(dimension(x)), typeof(x))
206+
Base.fill(x::AbstractQuantity, t::Tuple{}) = QuantityArray(fill(ustrip(x), t), copy(dimension(x)), typeof(x))
207207

208208
ulength(q::QuantityArray) = ulength(dimension(q))
209209
umass(q::QuantityArray) = umass(dimension(q))

src/math.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,37 @@ Base.:*(l::AbstractDimensions, r::AbstractDimensions) = map_dimensions(+, l, r)
22
Base.:*(l::AbstractQuantity, r::AbstractQuantity) = new_quantity(typeof(l), ustrip(l) * ustrip(r), dimension(l) * dimension(r))
33
Base.:*(l::AbstractQuantity, r::AbstractDimensions) = new_quantity(typeof(l), ustrip(l), dimension(l) * r)
44
Base.:*(l::AbstractDimensions, r::AbstractQuantity) = new_quantity(typeof(r), ustrip(r), l * dimension(r))
5-
Base.:*(l::AbstractQuantity, r) = new_quantity(typeof(l), ustrip(l) * r, dimension(l))
6-
Base.:*(l, r::AbstractQuantity) = new_quantity(typeof(r), l * ustrip(r), dimension(r))
5+
Base.:*(l::AbstractQuantity, r) = new_quantity(typeof(l), ustrip(l) * r, copy(dimension(l)))
6+
Base.:*(l, r::AbstractQuantity) = new_quantity(typeof(r), l * ustrip(r), copy(dimension(r)))
77
Base.:*(l::AbstractDimensions, r) = error("Please use an `AbstractQuantity` for multiplication. You used multiplication on types: $(typeof(l)) and $(typeof(r)).")
88
Base.:*(l, r::AbstractDimensions) = error("Please use an `AbstractQuantity` for multiplication. You used multiplication on types: $(typeof(l)) and $(typeof(r)).")
99

1010
Base.:/(l::AbstractDimensions, r::AbstractDimensions) = map_dimensions(-, l, r)
1111
Base.:/(l::AbstractQuantity, r::AbstractQuantity) = new_quantity(typeof(l), ustrip(l) / ustrip(r), dimension(l) / dimension(r))
1212
Base.:/(l::AbstractQuantity, r::AbstractDimensions) = new_quantity(typeof(l), ustrip(l), dimension(l) / r)
1313
Base.:/(l::AbstractDimensions, r::AbstractQuantity) = new_quantity(typeof(r), inv(ustrip(r)), l / dimension(r))
14-
Base.:/(l::AbstractQuantity, r) = new_quantity(typeof(l), ustrip(l) / r, dimension(l))
14+
Base.:/(l::AbstractQuantity, r) = new_quantity(typeof(l), ustrip(l) / r, copy(dimension(l)))
1515
Base.:/(l, r::AbstractQuantity) = l * inv(r)
1616
Base.:/(l::AbstractDimensions, r) = error("Please use an `AbstractQuantity` for division. You used division on types: $(typeof(l)) and $(typeof(r)).")
1717
Base.:/(l, r::AbstractDimensions) = error("Please use an `AbstractQuantity` for division. You used division on types: $(typeof(l)) and $(typeof(r)).")
1818

1919
Base.:+(l::AbstractQuantity, r::AbstractQuantity) =
2020
let
2121
dimension(l) == dimension(r) || throw(DimensionError(l, r))
22-
new_quantity(typeof(l), ustrip(l) + ustrip(r), dimension(l))
22+
new_quantity(typeof(l), ustrip(l) + ustrip(r), copy(dimension(l)))
2323
end
24-
Base.:-(l::AbstractQuantity) = new_quantity(typeof(l), -ustrip(l), dimension(l))
24+
Base.:-(l::AbstractQuantity) = new_quantity(typeof(l), -ustrip(l), copy(dimension(l)))
2525
Base.:-(l::AbstractQuantity, r::AbstractQuantity) = l + (-r)
2626

2727
Base.:+(l::AbstractQuantity, r) =
2828
let
2929
iszero(dimension(l)) || throw(DimensionError(l, r))
30-
new_quantity(typeof(l), ustrip(l) + r, dimension(l))
30+
new_quantity(typeof(l), ustrip(l) + r, copy(dimension(l)))
3131
end
3232
Base.:+(l, r::AbstractQuantity) =
3333
let
3434
iszero(dimension(r)) || throw(DimensionError(l, r))
35-
new_quantity(typeof(r), l + ustrip(r), dimension(r))
35+
new_quantity(typeof(r), l + ustrip(r), copy(dimension(r)))
3636
end
3737
Base.:-(l::AbstractQuantity, r) = l + (-r)
3838
Base.:-(l, r::AbstractQuantity) = l + (-r)
@@ -77,6 +77,6 @@ Base.sqrt(q::AbstractQuantity) = new_quantity(typeof(q), sqrt(ustrip(q)), sqrt(d
7777
Base.cbrt(d::AbstractDimensions{R}) where {R} = d^inv(convert(R, 3))
7878
Base.cbrt(q::AbstractQuantity) = new_quantity(typeof(q), cbrt(ustrip(q)), cbrt(dimension(q)))
7979

80-
Base.abs(q::AbstractQuantity) = new_quantity(typeof(q), abs(ustrip(q)), dimension(q))
80+
Base.abs(q::AbstractQuantity) = new_quantity(typeof(q), abs(ustrip(q)), copy(dimension(q)))
8181
Base.abs2(q::AbstractQuantity) = new_quantity(typeof(q), abs2(ustrip(q)), dimension(q)^2)
8282
Base.angle(q::AbstractQuantity{T}) where {T<:Complex} = angle(ustrip(q))

src/utils.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
return output
2727
end
2828

29-
Base.float(q::AbstractQuantity) = new_quantity(typeof(q), float(ustrip(q)), dimension(q))
29+
Base.float(q::AbstractQuantity) = new_quantity(typeof(q), float(ustrip(q)), copy(dimension(q)))
3030
Base.convert(::Type{T}, q::AbstractQuantity) where {T<:Real} =
3131
let
3232
@assert iszero(dimension(q)) "$(typeof(q)): $(q) has dimensions! Use `ustrip` instead."
@@ -45,12 +45,12 @@ Base.axes(q::AbstractQuantity) = axes(ustrip(q))
4545
Base.iterate(qd::AbstractQuantity, maybe_state...) =
4646
let subiterate=iterate(ustrip(qd), maybe_state...)
4747
subiterate === nothing && return nothing
48-
return new_quantity(typeof(qd), subiterate[1], dimension(qd)), subiterate[2]
48+
return new_quantity(typeof(qd), subiterate[1], copy(dimension(qd))), subiterate[2]
4949
end
5050
Base.ndims(::Type{<:AbstractQuantity{T}}) where {T} = ndims(T)
5151
Base.ndims(q::AbstractQuantity) = ndims(ustrip(q))
52-
Base.broadcastable(q::AbstractQuantity) = new_quantity(typeof(q), Base.broadcastable(ustrip(q)), dimension(q))
53-
Base.getindex(q::AbstractQuantity, i...) = new_quantity(typeof(q), getindex(ustrip(q), i...), dimension(q))
52+
Base.broadcastable(q::AbstractQuantity) = new_quantity(typeof(q), Base.broadcastable(ustrip(q)), copy(dimension(q)))
53+
Base.getindex(q::AbstractQuantity, i...) = new_quantity(typeof(q), getindex(ustrip(q), i...), copy(dimension(q)))
5454
Base.keys(q::AbstractQuantity) = keys(ustrip(q))
5555

5656

@@ -103,7 +103,7 @@ end
103103

104104
# Simple operations which return a full quantity (same dimensions)
105105
for f in (:real, :imag, :conj, :adjoint, :unsigned, :nextfloat, :prevfloat)
106-
@eval Base.$f(q::AbstractQuantity) = new_quantity(typeof(q), $f(ustrip(q)), dimension(q))
106+
@eval Base.$f(q::AbstractQuantity) = new_quantity(typeof(q), $f(ustrip(q)), copy(dimension(q)))
107107
end
108108

109109
# Base.one, typemin, typemax
@@ -116,20 +116,20 @@ for f in (:one, :typemin, :typemax)
116116
if f == :one # Return empty dimensions, as should be multiplicative identity.
117117
@eval Base.$f(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, $f(ustrip(q)), one(dimension(q)))
118118
else
119-
@eval Base.$f(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, $f(ustrip(q)), dimension(q))
119+
@eval Base.$f(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, $f(ustrip(q)), copy(dimension(q)))
120120
end
121121
end
122122
Base.one(::Type{D}) where {D<:AbstractDimensions} = D()
123123
Base.one(::D) where {D<:AbstractDimensions} = one(D)
124124

125125
# Additive identities (zero)
126-
Base.zero(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, zero(ustrip(q)), dimension(q))
126+
Base.zero(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, zero(ustrip(q)), copy(dimension(q)))
127127
Base.zero(::AbstractDimensions) = error("There is no such thing as an additive identity for a `AbstractDimensions` object, as + is only defined for `AbstractQuantity`.")
128128
Base.zero(::Type{<:AbstractQuantity}) = error("Cannot create an additive identity for a `AbstractQuantity` type, as the dimensions are unknown. Please use `zero(::AbstractQuantity)` instead.")
129129
Base.zero(::Type{<:AbstractDimensions}) = error("There is no such thing as an additive identity for a `AbstractDimensions` type, as + is only defined for `AbstractQuantity`.")
130130

131131
# Dimensionful 1 (oneunit)
132-
Base.oneunit(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, oneunit(ustrip(q)), dimension(q))
132+
Base.oneunit(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, oneunit(ustrip(q)), copy(dimension(q)))
133133
Base.oneunit(::AbstractDimensions) = error("There is no such thing as a dimensionful 1 for a `AbstractDimensions` object, as + is only defined for `AbstractQuantity`.")
134134
Base.oneunit(::Type{<:AbstractQuantity}) = error("Cannot create a dimensionful 1 for a `AbstractQuantity` type without knowing the dimensions. Please use `oneunit(::AbstractQuantity)` instead.")
135135
Base.oneunit(::Type{<:AbstractDimensions}) = error("There is no such thing as a dimensionful 1 for a `AbstractDimensions` type, as + is only defined for `AbstractQuantity`.")
@@ -177,14 +177,15 @@ tryrationalize(::Type{R}, x) where {R} = isinteger(x) ? convert(R, round(Int, x)
177177
Base.showerror(io::IO, e::DimensionError) = print(io, "DimensionError: ", e.q1, " and ", e.q2, " have incompatible dimensions")
178178

179179
Base.convert(::Type{Q}, q::AbstractQuantity) where {Q<:AbstractQuantity} = q
180-
Base.convert(::Type{Q}, q::AbstractQuantity) where {T,Q<:AbstractQuantity{T}} = new_quantity(Q, convert(T, ustrip(q)), dimension(q))
181-
Base.convert(::Type{Q}, q::AbstractQuantity) where {T,D,Q<:AbstractQuantity{T,D}} = new_quantity(Q, convert(T, ustrip(q)), convert(D, dimension(q)))
180+
Base.convert(::Type{Q}, q::AbstractQuantity) where {T,Q<:AbstractQuantity{T}} = new_quantity(Q, convert(T, ustrip(q)), copy(dimension(q)))
181+
Base.convert(::Type{Q}, q::AbstractQuantity) where {T,D,Q<:AbstractQuantity{T,D}} = new_quantity(Q, convert(T, ustrip(q)), convert(D, copy(dimension(q))))
182182

183183
Base.convert(::Type{D}, d::AbstractDimensions) where {D<:AbstractDimensions} = d
184184
Base.convert(::Type{D}, d::AbstractDimensions) where {R,D<:AbstractDimensions{R}} = D(d)
185185

186-
Base.copy(d::D) where {D<:AbstractDimensions} = map_dimensions(copy, d)
187-
Base.copy(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, copy(ustrip(q)), copy(dimension(q)))
186+
@inline Base.copy(d::D) where {D<:AbstractDimensions} = map_dimensions(copy, d)
187+
@inline Base.copy(d::D) where {D<:Dimensions} = d
188+
@inline Base.copy(q::Q) where {Q<:AbstractQuantity} = new_quantity(Q, copy(ustrip(q)), copy(dimension(q)))
188189

189190
"""
190191
ustrip(q::AbstractQuantity)

0 commit comments

Comments
 (0)