Skip to content

Commit a9c5b5c

Browse files
committed
use StatsBase weights
1 parent 2595265 commit a9c5b5c

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/samplers.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,11 @@ StatsBase.sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = Stats
7272

7373
function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}, names)
7474
t = e.traces
75-
st = deepcopy(t.priorities)
76-
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
77-
inds, priorities = rand(s.rng, st, s.batch_size)
78-
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
75+
p = collect(deepcopy(t.priorities))
76+
w = StatsBase.FrequencyWeights(p)
77+
w .*= e.sampleable_inds[1:end-1]
78+
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
79+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
7980
end
8081

8182
function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
@@ -242,12 +243,13 @@ fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds]
242243

243244
function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
244245
t = e.traces
245-
st = deepcopy(t.priorities)
246+
p = collect(deepcopy(t.priorities))
247+
w = StatsBase.FrequencyWeights(p)
246248
valids, ns = valid_range(s,e)
247-
st .*= valids[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
248-
inds, priorities = rand(s.rng, st, s.batch_size)
249+
w .*= valids[1:end-1]
250+
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
249251
merge(
250-
(key=t.keys[inds], priority=priorities),
252+
(key=t.keys[inds], priority=p[inds]),
251253
fetch(s, e, Val(names), inds, ns)
252254
)
253255
end

0 commit comments

Comments
 (0)