|
1 | 1 | using Random
|
| 2 | +export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler |
2 | 3 |
|
3 | 4 | struct SampleGenerator{S,T}
|
4 | 5 | sampler::S
|
@@ -233,3 +234,42 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
|
233 | 234 | StatsBase.sample(s, t.traces, Val(names), inds)
|
234 | 235 | )
|
235 | 236 | end
|
| 237 | + |
"""
    EpisodesSampler()
    EpisodesSampler{names}()

Sampler that returns every episode currently held in the trajectory, each one wrapped in
its own `Episode` container. Truncated episodes (e.g. cut short by the buffer capacity)
are returned too; there is at most one of them and it is always the first.

The type parameter `names` selects which traces are sampled. `EpisodesSampler()` is
shorthand for `EpisodesSampler{nothing}()`, which samples every key of the buffer; any
other value of `names` is used as the list of trace names to sample.
"""
struct EpisodesSampler{names} end

function EpisodesSampler()
    return EpisodesSampler{nothing}()
end
| 250 | + |
| 251 | + |
"""
    Episode(nt::NamedTuple)

Thin wrapper around a `NamedTuple` `nt`; the type parameter `names` mirrors the field
names of `nt`. `keys`, `haskey` and `getindex` called on an `Episode` are delegated to
the wrapped `nt`.
"""
struct Episode{names,N<:NamedTuple{names}}
    nt::N
end

# Explicit delegation to the wrapped NamedTuple. Extra positional and keyword arguments
# are passed through untouched.
@inline function Base.keys(ep::Episode, args...; kwargs...)
    return Base.keys(ep.nt, args...; kwargs...)
end

@inline function Base.haskey(ep::Episode, args...; kwargs...)
    return Base.haskey(ep.nt, args...; kwargs...)
end

@inline function Base.getindex(ep::Episode, args...; kwargs...)
    return Base.getindex(ep.nt, args...; kwargs...)
end
| 257 | + |
# Two-argument entry points: decide which traces to sample, then hand over to the
# three-argument method. An `EpisodesSampler{nothing}` samples every key of the buffer;
# any other `names` parameter is forwarded as-is.
function StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer)
    return StatsBase.sample(s, t, keys(t))
end

function StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where {names}
    return StatsBase.sample(s, t, names)
end
| 260 | + |
"""
    StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)

Return a vector of `Episode`s, one per episode found in `t`. Each `Episode` wraps a
`NamedTuple` with the fields `names`, holding the `collect`ed slice of the matching trace.

The buffer is scanned from index 1 while the index is below `length(t)`. When an index is
flagged in `t.sampleable_inds`, the index range of its episode is derived from
`t.episodes_lengths` and `t.step_numbers`, recorded, and the scan resumes right after that
range. Indices that are not flagged are skipped one at a time.
"""
function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
    episode_spans = UnitRange{Int}[]
    cursor = 1
    while cursor < length(t)
        if t.sampleable_inds[cursor] != 1
            cursor += 1
            continue
        end
        # NOTE(review): presumably `step_numbers[i]` is the 1-based step of index `i`
        # within its episode, so this points one slot past the episode's last step.
        # TODO confirm against EpisodesBuffer.
        span_stop = cursor + t.episodes_lengths[cursor] - t.step_numbers[cursor] + 1
        push!(episode_spans, cursor:span_stop)
        cursor = span_stop + 1
    end

    # Build one Episode from a single index range by slicing every requested trace.
    build_episode = function (span)
        slices = map(name -> collect(t[Val(name)][span]), names)
        return Episode(NamedTuple{names}(slices))
    end

    return [build_episode(span) for span in episode_spans]
end
0 commit comments