Skip to content

Commit 3e41a6a

Browse files
authored
Use sampler-based Random API (#206)
This prevents `rand` from returning a `ReinterpretArray` to avoid the performance problem with `ReinterpretArray` . This also supports specifying the RNG option.
1 parent ae6b911 commit 3e41a6a

File tree

5 files changed

+16
-5
lines changed

5 files changed

+16
-5
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
33
version = "0.8.4"
44

55
[deps]
6+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
67
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
78

89
[compat]

src/FixedPointNumbers.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ import Base: ==, <, <=, -, +, *, /, ~, isapprox,
77
big, rationalize, float, trunc, round, floor, ceil, bswap, clamp,
88
div, fld, rem, mod, mod1, fld1, min, max, minmax,
99
signed, unsigned, copysign, flipsign, signbit,
10-
rand, length
10+
length
1111

1212
import Statistics # for _mean_promote
13+
import Random: Random, AbstractRNG, SamplerType, rand!
1314

1415
using Base.Checked: checked_add, checked_sub, checked_div
1516

@@ -315,7 +316,7 @@ const UF = (N0f8, N6f10, N4f12, N2f14, N0f16)
315316
promote_rule(::Type{X}, ::Type{Tf}) where {X <: FixedPoint, Tf <: AbstractFloat} =
316317
promote_type(floattype(X), Tf)
317318

318-
# Note that `Tr` does not always have enough domains.
319+
# Note that `Tr` does not always have enough domains.
319320
promote_rule(::Type{X}, ::Type{Tr}) where {X <: FixedPoint, Tr <: Rational} = Tr
320321

321322
promote_rule(::Type{X}, ::Type{Ti}) where {X <: FixedPoint, Ti <: Integer} = floattype(X)
@@ -382,8 +383,15 @@ scaledual(::Type{Tdual}, x::AbstractArray{T}) where {Tdual, T <: FixedPoint} =
382383
throw(ArgumentError(String(take!(io))))
383384
end
384385

385-
rand(::Type{T}) where {T <: FixedPoint} = reinterpret(T, rand(rawtype(T)))
386-
rand(::Type{T}, sz::Dims) where {T <: FixedPoint} = reinterpret(T, rand(rawtype(T), sz))
386+
function Random.rand(r::AbstractRNG, ::SamplerType{X}) where X <: FixedPoint
387+
X(rand(r, rawtype(X)), 0)
388+
end
389+
390+
function rand!(r::AbstractRNG, A::Array{X}, ::SamplerType{X}) where {T, X <: FixedPoint{T}}
391+
At = unsafe_wrap(Array, reinterpret(Ptr{T}, pointer(A)), size(A))
392+
Random.rand!(r, At, SamplerType{T}())
393+
A
394+
end
387395

388396
if VERSION >= v"1.1" # work around https://github.com/JuliaLang/julia/issues/34121
389397
include("precompile.jl")

test/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FixedPointNumbers, Statistics, Test
1+
using FixedPointNumbers, Statistics, Random, Test
22
using FixedPointNumbers: bitwidth, rawtype, nbitsfrac
33

44
"""

test/fixed.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ end
436436
@test ndims(a) == 2 && eltype(a) === F
437437
@test size(a) == (3,5)
438438
end
439+
@test rand(MersenneTwister(1234), Q0f7) === -0.156Q0f7
439440
end
440441

441442
@testset "Promotion within Fixed" begin

test/normed.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ end
466466
@test ndims(a) == 2 && eltype(a) === N
467467
@test size(a) == (3,5)
468468
end
469+
@test rand(MersenneTwister(1234), N0f8) === 0.925N0f8
469470
end
470471

471472
@testset "Promotion within Normed" begin

0 commit comments

Comments
 (0)