1
1
using Random
2
- export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
2
+ export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler
3
3
4
4
struct SampleGenerator{S,T}
5
5
sampler:: S
@@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
29
29
export BatchSampler
30
30
31
31
struct BatchSampler{names}
32
- batch_size :: Int
32
+ batchsize :: Int
33
33
rng:: Random.AbstractRNG
34
34
end
35
35
36
36
"""
37
- BatchSampler{names}(;batch_size , rng=Random.GLOBAL_RNG)
38
- BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
37
+ BatchSampler{names}(;batchsize , rng=Random.GLOBAL_RNG)
38
+ BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)
39
39
40
- Uniformly sample **ONE** batch of `batch_size ` examples for each trace specified
40
+ Uniformly sample **ONE** batch of `batchsize ` examples for each trace specified
41
41
in `names`. If `names` is not set, all the traces will be sampled.
42
42
"""
43
- BatchSampler (batch_size ; kw... ) = BatchSampler (; batch_size = batch_size , kw... )
43
+ BatchSampler (batchsize ; kw... ) = BatchSampler (; batchsize = batchsize , kw... )
44
44
BatchSampler (; kw... ) = BatchSampler {nothing} (; kw... )
45
- BatchSampler {names} (batch_size ; kw... ) where {names} = BatchSampler {names} (; batch_size = batch_size , kw... )
46
- BatchSampler {names} (; batch_size , rng= Random. GLOBAL_RNG) where {names} = BatchSampler {names} (batch_size , rng)
45
+ BatchSampler {names} (batchsize ; kw... ) where {names} = BatchSampler {names} (; batchsize = batchsize , kw... )
46
+ BatchSampler {names} (; batchsize , rng= Random. GLOBAL_RNG) where {names} = BatchSampler {names} (batchsize , rng)
47
47
48
48
StatsBase. sample (s:: BatchSampler{nothing} , t:: AbstractTraces ) = StatsBase. sample (s, t, keys (t))
49
49
StatsBase. sample (s:: BatchSampler{names} , t:: AbstractTraces ) where {names} = StatsBase. sample (s, t, names)
50
50
51
51
function StatsBase. sample (s:: BatchSampler , t:: AbstractTraces , names, weights = StatsBase. UnitWeights {Int} (length (t)))
52
- inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batch_size )
52
+ inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batchsize )
53
53
NamedTuple {names} (map (x -> collect (t[Val (x)][inds]), names))
54
54
end
55
55
@@ -75,12 +75,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
75
75
p = collect (deepcopy (t. priorities))
76
76
w = StatsBase. FrequencyWeights (p)
77
77
w .*= e. sampleable_inds[1 : end - 1 ]
78
- inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size )
78
+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize )
79
79
NamedTuple {(:key, :priority, names...)} ((t. keys[inds], p[inds], map (x -> collect (t. traces[Val (x)][inds]), names)... ))
80
80
end
81
81
82
82
function StatsBase. sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
83
- inds, priorities = rand (s. rng, t. priorities, s. batch_size )
83
+ inds, priorities = rand (s. rng, t. priorities, s. batchsize )
84
84
NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x)][inds]), names)... ))
85
85
end
86
86
@@ -165,41 +165,42 @@ end
165
165
export NStepBatchSampler
166
166
167
167
"""
168
- NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)
168
+
169
+ NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)
169
170
170
171
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
171
172
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172
173
that in up to `n > 1` steps later in the buffer. The reward will be
173
174
the discounted sum of the `n` rewards, with `γ` as the discount factor.
174
175
175
- NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size ` is set
176
- to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
177
- of partial observability, for example when the state is approximated by `stack_size ` consecutive
176
+ NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize ` is set
177
+ to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
178
+ of partial observability, for example when the state is approximated by `stacksize ` consecutive
178
179
frames.
179
180
"""
180
- mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} }
181
+ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
181
182
n:: Int # !!! n starts from 1
182
183
γ:: Float32
183
- batch_size :: Int
184
- stack_size :: S
185
- rng:: Any
184
+ batchsize :: Int
185
+ stacksize :: S
186
+ rng:: R
186
187
end
187
188
188
189
NStepBatchSampler (t:: AbstractTraces ; kw... ) = NStepBatchSampler {keys(t)} (; kw... )
189
- function NStepBatchSampler {names} (; n, γ, batch_size = 32 , stack_size = nothing , rng= Random. GLOBAL_RNG ) where {names}
190
+ function NStepBatchSampler {names} (; n, γ, batchsize = 32 , stacksize = nothing , rng= Random. default_rng () ) where {names}
190
191
@assert n >= 1 " n must be ≥ 1."
191
- ss = stack_size == 1 ? nothing : stack_size
192
- NStepBatchSampler {names, typeof(ss)} (n, γ, batch_size , ss, rng)
192
+ ss = stacksize == 1 ? nothing : stacksize
193
+ NStepBatchSampler {names, typeof(ss), typeof(rng) } (n, γ, batchsize , ss, rng)
193
194
end
194
195
195
- # return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
196
+ # return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
196
197
function valid_range (s:: NStepBatchSampler , eb:: EpisodesBuffer )
197
198
range = copy (eb. sampleable_inds)
198
199
ns = Vector {Int} (undef, length (eb. sampleable_inds))
199
- stack_size = isnothing (s. stack_size ) ? 1 : s. stack_size
200
+ stacksize = isnothing (s. stacksize ) ? 1 : s. stacksize
200
201
for idx in eachindex (range)
201
202
step_number = eb. step_numbers[idx]
202
- range[idx] = step_number >= stack_size && eb. sampleable_inds[idx]
203
+ range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
203
204
ns[idx] = min (s. n, eb. episodes_lengths[idx] - step_number + 1 )
204
205
end
205
206
return range, ns
@@ -211,19 +212,19 @@ end
211
212
212
213
function StatsBase. sample (s:: NStepBatchSampler , t:: EpisodesBuffer , :: Val{names} ) where names
213
214
weights, ns = valid_range (s, t)
214
- inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batch_size )
215
+ inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batchsize )
215
216
fetch (s, t, Val (names), inds, ns)
216
217
end
217
218
218
219
function fetch (s:: NStepBatchSampler , ts:: EpisodesBuffer , :: Val{names} , inds, ns) where names
219
220
NamedTuple {names} (map (name -> collect (fetch (s, ts[name], Val (name), inds, ns[inds])), names))
220
221
end
221
222
222
- # state and next_state have specialized fetch methods due to stack_size
223
+ # state and next_state have specialized fetch methods due to stacksize
223
224
fetch (:: NStepBatchSampler{names, Nothing} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[inds]
224
- fetch (s:: NStepBatchSampler{names, Int} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[[x + i for i in - s. stack_size + 1 : 0 , x in inds]]
225
+ fetch (s:: NStepBatchSampler{names, Int} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[[x + i for i in - s. stacksize + 1 : 0 , x in inds]]
225
226
fetch (:: NStepBatchSampler{names, Nothing} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[inds .+ ns .- 1 ]
226
- fetch (s:: NStepBatchSampler{names, Int} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in - s. stack_size + 1 : 0 , (idx,x) in enumerate (inds)]]
227
+ fetch (s:: NStepBatchSampler{names, Int} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in - s. stacksize + 1 : 0 , (idx,x) in enumerate (inds)]]
227
228
228
229
# reward due to discounting
229
230
function fetch (s:: NStepBatchSampler , trace:: AbstractTrace , :: Val{:reward} , inds, ns)
@@ -247,7 +248,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
247
248
w = StatsBase. FrequencyWeights (p)
248
249
valids, ns = valid_range (s,e)
249
250
w .*= valids[1 : end - 1 ]
250
- inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size )
251
+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize )
251
252
merge (
252
253
(key= t. keys[inds], priority= p[inds]),
253
254
fetch (s, e, Val (names), inds, ns)
@@ -297,3 +298,74 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
297
298
298
299
return [make_episode (t, r, names) for r in ranges]
299
300
end
301
+
302
+ # ####MultiStepSampler
303
+
304
+ """
305
+ MultiStepSampler{names}(batchsize, stacksize, n, rng)
306
+
307
+ Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308
+ `x`. The samples are returned in an array of batchsize elements. For each element, n is
309
+ truncated by the end of its episode. This means that the dimensions of each sample are not
310
+ the same.
311
+ """
312
+ struct MultiStepSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
313
+ n:: Int
314
+ batchsize:: Int
315
+ stacksize:: Int
316
+ rng:: R
317
+ end
318
+
319
+ MultiStepSampler (t:: AbstractTraces ; kw... ) = MultiStepSampler {keys(t)} (; kw... )
320
+ function MultiStepSampler {names} (; n, batchsize= 32 , stacksize= nothing , rng= Random. default_rng ()) where {names}
321
+ @assert n >= 1 " n must be ≥ 1."
322
+ ss = stacksize == 1 ? nothing : stacksize
323
+ MultiStepSampler {names, typeof(ss), typeof(rng)} (n, batchsize, ss, rng)
324
+ end
325
+
326
+ function valid_range (s:: MultiStepSampler , eb:: EpisodesBuffer )
327
+ range = copy (eb. sampleable_inds)
328
+ ns = Vector {Int} (undef, length (eb. sampleable_inds))
329
+ stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
330
+ for idx in eachindex (range)
331
+ step_number = eb. step_numbers[idx]
332
+ range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
333
+ ns[idx] = min (s. n, eb. episodes_lengths[idx] - step_number + 1 )
334
+ end
335
+ return range, ns
336
+ end
337
+
338
+ function StatsBase. sample (s:: MultiStepSampler{names} , ts) where {names}
339
+ StatsBase. sample (s, ts, Val (names))
340
+ end
341
+
342
+ function StatsBase. sample (s:: MultiStepSampler , t:: EpisodesBuffer , :: Val{names} ) where names
343
+ weights, ns = valid_range (s, t)
344
+ inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batchsize)
345
+ fetch (s, t, Val (names), inds, ns)
346
+ end
347
+
348
+ function fetch (s:: MultiStepSampler , ts:: EpisodesBuffer , :: Val{names} , inds, ns) where names
349
+ NamedTuple {names} (map (name -> collect (fetch (s, ts[name], Val (name), inds, ns[inds])), names))
350
+ end
351
+
352
+ function fetch (:: MultiStepSampler , trace, :: Val , inds, ns)
353
+ [trace[idx: (idx + ns[i] - 1 )] for (i,idx) in enumerate (inds)]
354
+ end
355
+
356
+ function fetch (s:: MultiStepSampler{names, Int} , trace:: AbstractTrace , :: Union{Val{:state}, Val{:next_state}} , inds, ns) where {names}
357
+ [trace[[idx + i + n - 1 for i in - s. stacksize+ 1 : 0 , n in 1 : ns[j]]] for (j,idx) in enumerate (inds)]
358
+ end
359
+
360
+ function StatsBase. sample (s:: MultiStepSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
361
+ t = e. traces
362
+ p = collect (deepcopy (t. priorities))
363
+ w = StatsBase. FrequencyWeights (p)
364
+ valids, ns = valid_range (s,e)
365
+ w .*= valids[1 : end - 1 ]
366
+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize)
367
+ merge (
368
+ (key= t. keys[inds], priority= p[inds]),
369
+ fetch (s, e, Val (names), inds, ns)
370
+ )
371
+ end
0 commit comments