Skip to content

Commit c84ff9a

Browse files
committed
Backport partially "Optimize multiplication for Normed (JuliaMath#213)"
This adds `wrapping_mul` and `checked_mul` binary operations for `Normed`. This replaces most of Normed's implementation of multiplication with integer operations. This improves the speed in many cases and the accuracy in some cases.
1 parent f5f333d commit c84ff9a

File tree

4 files changed

+76
-3
lines changed

4 files changed

+76
-3
lines changed

src/FixedPointNumbers.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ floattype(::Type{Base.TwicePrecision{T}}) where T<:Union{Float16,Float32} = wide
180180

181181
# Convert a fixed-point number to the floating-point type chosen by `floattype` for it.
function float(x::FixedPoint)
    F = floattype(x)
    return convert(F, x)
end
182182

183+
# Generic fallback: multiply in floating point, then bring the product back into `X` with `%`.
function wrapping_mul(x::X, y::X) where {X <: FixedPoint}
    product = float(x) * float(y)
    return product % X
end
184+
# Same-type `FixedPoint` multiplication delegates to `wrapping_mul`.
function *(x::X, y::X) where {X <: FixedPoint}
    return wrapping_mul(x, y)
end
185+
183186
function minmax(x::X, y::X) where {X <: FixedPoint}
184187
a, b = minmax(reinterpret(x), reinterpret(y))
185188
X(a,0), X(b,0)

src/normed.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ function rem(x::Float64, ::Type{N}) where {f, N <: Normed{UInt64,f}}
127127
reinterpret(N, r << UInt8(f - 53) - unsigned(signed(r) >> 0x35))
128128
end
129129

130-
131130
function (::Type{T})(x::Normed) where {T <: AbstractFloat}
132131
# The following optimization for constant division may cause rounding errors.
133132
# y = reinterpret(x)*(one(rawtype(x))/convert(T, rawone(x)))
@@ -248,8 +247,37 @@ Base.BigFloat(x::Normed) = reinterpret(x) / BigFloat(rawone(x))
248247

249248
# Build the rational `reinterpret(x) // rawone(x)` for a `Normed` value.
function Base.Rational(x::Normed)
    numer = reinterpret(x)
    return numer // rawone(x)
end
250249

251-
# unchecked arithmetic
252-
*(x::T, y::T) where {T <: Normed} = convert(T,convert(floattype(T), x)*convert(floattype(T), y))
250+
# Divide `x` by `2^f - 1`, rounding to nearest.
# The quotient is expected to occupy the lower half of the bits of `x`.

# Generic method: bias `x` by the truncated half-divisor `2^(f-1) - 1`, then divide.
function div_2fm1(x::T, ::Val{f}) where {T, f}
    unit = T(1)
    divisor = (unit << f) - 0x1
    bias = (unit << (f - 1)) - 0x1
    return (x + bias) ÷ divisor
end

# `f == 1`: the divisor is 1, so `x` is returned unchanged.
div_2fm1(x::T, ::Val{1}) where {T} = x

# The methods below replace the integer division with shifts and adds:
# with `t = x + 2^(f-1)`, the result is `(t + (t >> f)) >> f`.
function div_2fm1(x::UInt16, ::Val{8})
    t = x + 0x80
    return ((t >> 0x8) + t) >> 0x8
end

function div_2fm1(x::UInt32, ::Val{16})
    t = x + 0x8000
    return ((t >> 0x10) + t) >> 0x10
end

function div_2fm1(x::UInt64, ::Val{32})
    t = x + 0x80000000
    return ((t >> 0x20) + t) >> 0x20
end

function div_2fm1(x::UInt128, ::Val{64})
    t = x + 0x8000000000000000
    return ((t >> 0x40) + t) >> 0x40
end
257+
258+
# wrapping arithmetic
259+
function wrapping_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
    # Multiply the raw values in the widened integer type returned by `widemul`.
    wide = widemul(x.i, y.i)
    # Rescale by `2^f - 1` with rounding, then wrap into `T` by keeping the low bits.
    raw = div_2fm1(wide, Val(Int(f))) % T
    return N(raw, 0)
end
263+
264+
# checked arithmetic
265+
# Generic checked multiplication: multiply in floating point and reject any product
# that is not below `typemax(N) + eps(N)/2`.
function checked_mul(x::N, y::N) where {N <: Normed}
    product = float(x) * float(y)
    limit = typemax(N) + eps(N)/2
    if !(product < limit)
        throw_overflowerror(:*, x, y)
    end
    return product % N
end
270+
# Checked multiplication for `Normed` types backed by `UInt8`..`UInt64`, using integer
# arithmetic only.
#
# The raw values are multiplied in the widened type and rescaled by `2^f - 1` with
# round-to-nearest (`div_2fm1`). `throw_overflowerror(:*, x, y)` is called when the
# rounded quotient would not fit in `T`.
function checked_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
    # With `f == bitwidth(T)` the raw one equals `typemax(T)`, so the rescaled product
    # of two raw values always fits in `T`; no check is needed.
    f == bitwidth(T) && return wrapping_mul(x, y)
    z = widemul(x.i, y.i)
    # `m` is the largest wide product that still rounds to `typemax(T)`:
    #   div_2fm1(m)     == typemax(T)
    #   div_2fm1(m + 1) == typemax(T) + 1
    # The divisor `2^f - 1` is odd, so there are no ties and the bound is inclusive.
    m = widemul(typemax(N).i, rawone(N)) + (rawone(N) >> 0x1)
    # `<=` rather than `<`: a strict comparison rejects the valid case `z == m`,
    # e.g. `typemax(N) * one(N)` when `f == 1`, where `rawone(N) >> 0x1 == 0`.
    z <= m || throw_overflowerror(:*, x, y)
    return N(div_2fm1(z, Val(Int(f))) % T, 0)
end
277+
278+
# For backward compatibility, `*` on `Normed` overrides the default arithmetic with the
# `checked` variant.
function *(x::N, y::N) where {N <: Normed}
    return checked_mul(x, y)
end
280+
253281
# Same-type division: divide in `floattype(T)` and convert the quotient back to `T`.
function /(x::T, y::T) where {T <: Normed}
    F = floattype(T)
    quotient = convert(F, x) / convert(F, y)
    return convert(T, quotient)
end
254282

255283
# Functions

test/fixed.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,28 @@ end
256256
@test (-67.2 % T).i == round(Int, -67.2*512) % Int16
257257
end
258258

259+
@testset "mul" begin
    # Local helper so the checks below read as wrapping multiplication on `Fixed`.
    wrapping_mul(x::F, y::F) where {F <: Fixed} = x * y

    # Boundary values for a thin selection of `Fixed` types.
    for F in target(Fixed; ex = :thin)
        lo, hi = typemin(F), typemax(F)
        @test wrapping_mul(hi, zero(F)) === zero(F)

        # FIXME: Both the rhs and lhs of the following tests may be inaccurate due to `rem`
        if F === Fixed{Int128,127}
            continue
        end

        @test wrapping_mul(F(-1), hi) === -hi

        @test wrapping_mul(lo, hi) === (big(lo) * big(hi)) % F

        @test wrapping_mul(lo, lo) === (big(lo)^2) % F
    end

    # Exhaustive comparison against floating-point multiplication for the 8-bit types.
    for F in target(Fixed, :i8; ex = :thin)
        xs = typemin(F):eps(F):typemax(F)
        grid = ((x, y) for x in xs, y in xs)
        fmul(x, y) = float(x) * float(y) # note that precision(Float32) < 32
        matches((x, y)) = wrapping_mul(x, y) === fmul(x, y) % F
        @test all(matches, grid)
    end
end
280+
259281
@testset "rounding" begin
260282
for sym in (:i8, :i16, :i32, :i64)
261283
T = symbol_to_inttype(Fixed, sym)

test/normed.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,26 @@ end
283283
end
284284
end
285285

286+
@testset "mul" begin
    checked_mul = FixedPointNumbers.checked_mul

    # Boundary values for a thin selection of `Normed` types.
    for N in target(Normed; ex = :thin)
        top = typemax(N)
        @test checked_mul(top, zero(N)) === zero(N)

        @test checked_mul(one(N), top) === top

        if top != 1
            @test_throws OverflowError checked_mul(top, top)
        end
    end

    # Exhaustive comparison against floating-point multiplication for the 8-bit types.
    for N in target(Normed, :i8; ex = :thin)
        xs = typemin(N):eps(N):typemax(N)
        grid = ((x, y) for x in xs, y in xs)
        fmul(x, y) = float(x) * float(y) # note that precision(Float32) < 32
        function agrees((x, y))
            p = fmul(x, y)
            # Products outside `[typemin(N), typemax(N)]` are not compared.
            typemin(N) <= p <= typemax(N) || return true
            return (p % N) === checked_mul(x, y)
        end
        @test all(agrees, grid)
    end
end
305+
286306
@testset "div/fld1" begin
287307
@test div(reinterpret(N0f8, 0x10), reinterpret(N0f8, 0x02)) == fld(reinterpret(N0f8, 0x10), reinterpret(N0f8, 0x02)) == 8
288308
@test div(reinterpret(N0f8, 0x0f), reinterpret(N0f8, 0x02)) == fld(reinterpret(N0f8, 0x0f), reinterpret(N0f8, 0x02)) == 7

0 commit comments

Comments
 (0)