Skip to content

Commit 0e76ebb

Browse files
Merge pull request #51 from JuliaReinforcementLearning/episodesampler
2 parents 937f1c6 + 5635aa1 commit 0e76ebb

File tree

4 files changed

+273
-186
lines changed

4 files changed

+273
-186
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.3.2"
3+
version = "0.3.3"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ julia> for batch in t
6969
- `BatchSampler`
7070
- `MetaSampler`
7171
- `MultiBatchSampler`
72+
- `EpisodesSampler`
7273

7374
**Controllers**
7475

src/samplers.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Random
2+
export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
23

34
struct SampleGenerator{S,T}
45
sampler::S
@@ -233,3 +234,42 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
233234
StatsBase.sample(s, t.traces, Val(names), inds)
234235
)
235236
end
237+
"""
    EpisodesSampler()
    EpisodesSampler{names}()

Sampler that returns every episode stored in a trajectory, each one wrapped in its own
`Episode` container.

The type parameter `names` selects which traces are sampled. The zero-argument
constructor sets `names` to `nothing`, in which case all keys of the sampled buffer are
used.

Episodes that were truncated (for instance because the buffer reached its capacity) are
returned as well. At most one episode can be truncated, and it is always the first one.
"""
struct EpisodesSampler{names} end

# Default: no explicit trace names, i.e. sample every trace of the buffer.
EpisodesSampler() = EpisodesSampler{nothing}()
249+
#EpisodesSampler{names}() = new{names}()
250+
251+
252+
"""
    Episode(nt::NamedTuple)

Container for the data of a single episode. The traces are stored in the named tuple `nt`,
whose field names are exposed as the type parameter `names`.
"""
struct Episode{names,N<:NamedTuple{names}}
    nt::N
end
255+
256+
# Delegate `keys`, `haskey` and `getindex` called on an `Episode` to its wrapped `nt`
# field (assumes `@forward` is the usual method-forwarding macro brought in elsewhere in
# the package -- TODO confirm which package provides it).
@forward Episode.nt Base.keys, Base.haskey, Base.getindex
257+
258+
# No trace names were given to the sampler: sample every key of the buffer.
function StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer)
    return StatsBase.sample(s, t, keys(t))
end
259+
# Trace names were supplied as the sampler's type parameter: sample only those.
function StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where {names}
    return StatsBase.sample(s, t, names)
end
260+
261+
# Split the content of `t` into one index range per episode and return a vector of
# `Episode`s, each holding a collected copy of the traces listed in `names` over that
# range.
#
# The buffer is scanned from index 1. Whenever the current index is flagged in
# `t.sampleable_inds`, the end of the surrounding episode is derived from
# `t.episodes_lengths` and `t.step_numbers`, the range is recorded, and scanning resumes
# right after it. Indices that are not flagged are skipped one at a time.
function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
    episode_ranges = UnitRange{Int}[]
    cursor = 1
    while cursor < length(t)
        if t.sampleable_inds[cursor] != 1
            cursor += 1
            continue
        end
        # Remaining length of the episode seen from `cursor`, plus one extra slot
        # (presumably the episode's final state, which is itself not sampleable --
        # TODO confirm against EpisodesBuffer).
        stop = cursor + t.episodes_lengths[cursor] - t.step_numbers[cursor] + 1
        push!(episode_ranges, cursor:stop)
        cursor = stop + 1
    end

    # Materialise the requested traces over one episode range.
    gather(r) = Episode(NamedTuple{names}(map(name -> collect(t[Val(name)][r]), names)))

    return [gather(r) for r in episode_ranges]
end

0 commit comments

Comments
 (0)