Skip to content

Commit 29b7cb4

Browse files
authored
Merge pull request #33 from JuliaReinforcementLearning/test_on_GPU
Fix #29
2 parents 50439ba + 2e55c19 commit 29b7cb4

File tree

2 files changed

+54
-39
lines changed

2 files changed

+54
-39
lines changed

src/samplers.jl

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
using Random
22

3-
abstract type AbstractSampler end
3+
struct SampleGenerator{S,T}
4+
sampler::S
5+
traces::T
6+
end
7+
8+
Base.iterate(s::SampleGenerator) = sample(s.sampler, s.traces), nothing
9+
Base.iterate(s::SampleGenerator, ::Nothing) = nothing
410

511
#####
612
# DummySampler
713
#####
814

915
export DummySampler
1016

17+
"""
18+
Just return the underlying traces.
19+
"""
1120
struct DummySampler end
1221

13-
sample(s::DummySampler, t::AbstractTraces) = t
22+
sample(::DummySampler, t) = t
1423

1524
#####
1625
# BatchSampler
1726
#####
1827

1928
export BatchSampler
20-
struct BatchSampler{names} <: AbstractSampler
29+
30+
struct BatchSampler{names}
2131
batch_size::Int
2232
rng::Random.AbstractRNG
2333
end
@@ -26,10 +36,8 @@ end
2636
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
2737
BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
2838
29-
Uniformly sample a batch of examples for each trace specified in `names`.
30-
By default, all the traces will be sampled.
31-
32-
See also [`sample`](@ref).
39+
Uniformly sample **ONE** batch of `batch_size` examples for each trace specified
40+
in `names`. If `names` is not set, all the traces will be sampled.
3341
"""
3442
BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
3543
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
@@ -44,6 +52,15 @@ function sample(s::BatchSampler, t::AbstractTraces, names)
4452
NamedTuple{names}(map(x -> t[x][inds], names))
4553
end
4654

55+
# !!! avoid iterating an empty trajectory
56+
function Base.iterate(s::SampleGenerator{<:BatchSampler})
57+
if length(s.traces) > 0
58+
sample(s.sampler, s.traces), nothing
59+
else
60+
nothing
61+
end
62+
end
63+
4764
#####
4865

4966
sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = sample(s, t, keys(t.traces))
@@ -75,7 +92,7 @@ initializing an agent.
7592
MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100))
7693
```
7794
"""
78-
struct MetaSampler{names,T} <: AbstractSampler
95+
struct MetaSampler{names,T}
7996
samplers::NamedTuple{names,T}
8097
end
8198

@@ -104,7 +121,7 @@ MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
104121
critic = MultiBatchSampler(BatchSampler(100), 5))
105122
```
106123
"""
107-
struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler
124+
struct MultiBatchSampler{S}
108125
sampler::S
109126
n::Int
110127
end

src/trajectory.jl

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,12 @@ function Base.bind(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}
6969
bind(t.controler.ch_out, task)
7070
end
7171

72-
# !!! by default we assume `x` is a complete example which contains all the traces
73-
# When doing partial inserting, the result of undefined
74-
function Base.push!(t::Trajectory, x)
75-
push!(t.container, x)
76-
on_insert!(t.controller, 1)
77-
end
78-
7972
Base.setindex!(t::Trajectory, v, I...) = setindex!(t.container, v, I...)
8073

74+
#####
75+
# in
76+
#####
77+
8178
struct CallMsg
8279
f::Any
8380
args::Tuple
@@ -93,32 +90,33 @@ function Base.append!(t::Trajectory, x)
9390
on_insert!(t.controller, length(x))
9491
end
9592

96-
# !!! bypass the controller
97-
sample(t::Trajectory) = sample(t.sampler, t.container)
93+
# !!! by default we assume `x` is a complete example which contains all the traces
94+
# When doing partial inserting, the result of undefined
95+
function Base.push!(t::Trajectory, x)
96+
push!(t.container, x)
97+
on_insert!(t)
98+
end
9899

99-
on_sample!(t::Trajectory) = on_sample!(t.controller)
100+
on_insert!(t::Trajectory) = on_insert!(t, 1)
101+
on_insert!(t::Trajectory, n::Int) = on_insert!(t.controller, n)
100102

101-
function Base.take!(t::Trajectory)
102-
if on_sample!(t)
103-
sample(t.sampler, t.container) |> t.transformer
104-
else
105-
nothing
106-
end
107-
end
103+
#####
104+
# out
105+
#####
108106

109-
function Base.iterate(t::Trajectory)
110-
x = take!(t)
111-
if isnothing(x)
112-
nothing
113-
else
114-
x, true
115-
end
116-
end
107+
SampleGenerator(t::Trajectory) = SampleGenerator(t.sampler, t.container)
108+
109+
on_sample!(t::Trajectory) = on_sample!(t.controller)
110+
sample(t::Trajectory) = sample(t.sampler, t.container)
111+
112+
"""
113+
Keep sampling batches from the trajectory until the trajectory is not ready to
114+
be sampled yet due to the `controller`.
115+
"""
116+
iter(t::Trajectory) = Iterators.takewhile(_ -> on_sample!(t), Iterators.cycle(SampleGenerator(t)))
117117

118-
Base.iterate(t::Trajectory, state) = iterate(t)
118+
Base.iterate(t::Trajectory, args...) = iterate(iter(t), args...)
119+
Base.IteratorSize(t::Trajectory) = Base.IteratorSize(iter(t))
119120

120121
Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...) = iterate(t.controller.ch_out, args...)
121-
Base.take!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = take!(t.controller.ch_out)
122-
123-
Base.IteratorSize(::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = Base.IsInfinite()
124-
Base.IteratorSize(::Trajectory) = Base.SizeUnknown()
122+
Base.IteratorSize(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = Base.IteratorSize(t.controller.ch_out)

0 commit comments

Comments
 (0)