|
1 | 1 | using Random
|
2 | 2 |
|
3 |
| -abstract type AbstractSampler end |
| 3 | +struct SampleGenerator{S,T} |
| 4 | + sampler::S |
| 5 | + traces::T |
| 6 | +end |
| 7 | + |
| 8 | +Base.iterate(s::SampleGenerator) = sample(s.sampler, s.traces), nothing |
| 9 | +Base.iterate(s::SampleGenerator, ::Nothing) = nothing |
4 | 10 |
|
5 | 11 | #####
|
6 | 12 | # DummySampler
|
7 | 13 | #####
|
8 | 14 |
|
9 | 15 | export DummySampler
|
10 | 16 |
|
| 17 | +""" |
| 18 | +Just return the underlying traces. |
| 19 | +""" |
11 | 20 | struct DummySampler end
|
12 | 21 |
|
13 |
| -sample(s::DummySampler, t::AbstractTraces) = t |
| 22 | +sample(::DummySampler, t) = t |
14 | 23 |
|
15 | 24 | #####
|
16 | 25 | # BatchSampler
|
17 | 26 | #####
|
18 | 27 |
|
19 | 28 | export BatchSampler
|
20 |
| -struct BatchSampler{names} <: AbstractSampler |
| 29 | + |
| 30 | +struct BatchSampler{names} |
21 | 31 | batch_size::Int
|
22 | 32 | rng::Random.AbstractRNG
|
23 | 33 | end
|
|
26 | 36 | BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
|
27 | 37 | BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
|
28 | 38 |
|
29 |
| -Uniformly sample a batch of examples for each trace specified in `names`. |
30 |
| -By default, all the traces will be sampled. |
31 |
| -
|
32 |
| -See also [`sample`](@ref). |
| 39 | +Uniformly sample **ONE** batch of `batch_size` examples for each trace specified |
| 40 | +in `names`. If `names` is not set, all the traces will be sampled. |
33 | 41 | """
|
34 | 42 | BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
|
35 | 43 | BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
|
36 | 44 | BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
|
37 | 45 | BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batch_size, rng)
|
38 | 46 |
|
39 | 47 | sample(s::BatchSampler{nothing}, t::AbstractTraces) = sample(s, t, keys(t))
|
40 |
| -sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, names) |
| 48 | +sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = _sample(s, t, names) |
41 | 49 |
|
42 | 50 | function sample(s::BatchSampler, t::AbstractTraces, names)
|
43 | 51 | inds = rand(s.rng, 1:length(t), s.batch_size)
|
44 | 52 | NamedTuple{names}(map(x -> t[x][inds], names))
|
45 | 53 | end
|
46 | 54 |
|
| 55 | +# !!! avoid iterating an empty trajectory |
| 56 | +function Base.iterate(s::SampleGenerator{<:BatchSampler}) |
| 57 | + if length(s.traces) > 0 |
| 58 | + sample(s.sampler, s.traces), nothing |
| 59 | + else |
| 60 | + nothing |
| 61 | + end |
| 62 | +end |
| 63 | + |
47 | 64 | #####
|
48 | 65 |
|
49 | 66 | sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = sample(s, t, keys(t.traces))
|
@@ -75,7 +92,7 @@ initializing an agent.
|
75 | 92 | MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100))
|
76 | 93 | ```
|
77 | 94 | """
|
78 |
| -struct MetaSampler{names,T} <: AbstractSampler |
| 95 | +struct MetaSampler{names,T} |
79 | 96 | samplers::NamedTuple{names,T}
|
80 | 97 | end
|
81 | 98 |
|
@@ -104,7 +121,7 @@ MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
|
104 | 121 | critic = MultiBatchSampler(BatchSampler(100), 5))
|
105 | 122 | ```
|
106 | 123 | """
|
107 |
| -struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler |
| 124 | +struct MultiBatchSampler{S} |
108 | 125 | sampler::S
|
109 | 126 | n::Int
|
110 | 127 | end
|
|
0 commit comments