@@ -209,8 +209,21 @@ function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
209
209
NamedTuple {SSLART} (map (collect, (s, s′, l, a, r, t)))
210
210
end
211
211
212
- function StatsBase. sample (s:: NStepBatchSampler{names} , t:: CircularPrioritizedTraces ) where {names}
213
- inds, priorities = rand (s. rng, t. priorities, s. batch_size)
212
+ function StatsBase. sample (s:: NStepBatchSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} )
213
+ t = e. traces
214
+ st = deepcopy (t. priorities)
215
+ st .*= e. sampleable_inds[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
216
+ inds, priorities = rand (s. rng, st, s. batch_size)
217
+
218
+ NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[x][inds]), names)... ))
219
+ end
220
+
221
+ function StatsBase. sample (s:: NStepBatchSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
222
+ t = e. traces
223
+ st = deepcopy (t. priorities)
224
+ st .*= e. sampleable_inds[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
225
+ inds, priorities = rand (s. rng, st, s. batch_size)
226
+
214
227
merge (
215
228
(key= t. keys[inds], priority= priorities),
216
229
StatsBase. sample (s, t. traces, Val (names), inds)
0 commit comments