@@ -72,10 +72,11 @@ StatsBase.sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = Stats
72
72
73
73
function StatsBase. sample (s:: BatchSampler , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} , names)
74
74
t = e. traces
75
- st = deepcopy (t. priorities)
76
- st .*= e. sampleable_inds[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
77
- inds, priorities = rand (s. rng, st, s. batch_size)
78
- NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x)][inds]), names)... ))
75
+ p = collect (deepcopy (t. priorities))
76
+ w = StatsBase. FrequencyWeights (p)
77
+ w .*= e. sampleable_inds[1 : end - 1 ]
78
+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size)
79
+ NamedTuple {(:key, :priority, names...)} ((t. keys[inds], p[inds], map (x -> collect (t. traces[Val (x)][inds]), names)... ))
79
80
end
80
81
81
82
function StatsBase. sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
@@ -242,12 +243,13 @@ fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds]
242
243
243
244
function StatsBase. sample (s:: NStepBatchSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
244
245
t = e. traces
245
- st = deepcopy (t. priorities)
246
+ p = collect (deepcopy (t. priorities))
247
+ w = StatsBase. FrequencyWeights (p)
246
248
valids, ns = valid_range (s,e)
247
- st .*= valids[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
248
- inds, priorities = rand (s. rng, st , s. batch_size)
249
+ w .*= valids[1 : end - 1 ]
250
+ inds = StatsBase . sample (s. rng, eachindex (w), w , s. batch_size)
249
251
merge (
250
- (key= t. keys[inds], priority= priorities ),
252
+ (key= t. keys[inds], priority= p[inds] ),
251
253
fetch (s, e, Val (names), inds, ns)
252
254
)
253
255
end
0 commit comments