Skip to content

Commit 70f1a1f

Browse files
authored
Merge pull request #58 from JuliaReinforcementLearning/MultiStepSampler
Multi step sampler
2 parents e55919c + 7d2b8ee commit 70f1a1f

File tree

3 files changed

+153
-37
lines changed

3 files changed

+153
-37
lines changed

src/normalization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ for f in (:push!, :pushfirst!, :append!, :prepend!)
192192
end
193193

194194
function StatsBase.sample(s::BatchSampler, nt::NormalizedTraces, names, weights = StatsBase.UnitWeights{Int}(length(nt)))
195-
inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batch_size)
195+
inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batchsize)
196196
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
197197
NamedTuple{names}(collect(maybe_normalize(nt[x][inds], x)) for x in names)
198198
end

src/samplers.jl

Lines changed: 102 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Random
2-
export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
2+
export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler
33

44
struct SampleGenerator{S,T}
55
sampler::S
@@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
2929
export BatchSampler
3030

3131
struct BatchSampler{names}
32-
batch_size::Int
32+
batchsize::Int
3333
rng::Random.AbstractRNG
3434
end
3535

3636
"""
37-
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
38-
BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
37+
BatchSampler{names}(;batchsize, rng=Random.GLOBAL_RNG)
38+
BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)
3939
40-
Uniformly sample **ONE** batch of `batch_size` examples for each trace specified
40+
Uniformly sample **ONE** batch of `batchsize` examples for each trace specified
4141
in `names`. If `names` is not set, all the traces will be sampled.
4242
"""
43-
BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
43+
BatchSampler(batchsize; kw...) = BatchSampler(; batchsize=batchsize, kw...)
4444
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
45-
BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
46-
BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batch_size, rng)
45+
BatchSampler{names}(batchsize; kw...) where {names} = BatchSampler{names}(; batchsize=batchsize, kw...)
46+
BatchSampler{names}(; batchsize, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batchsize, rng)
4747

4848
StatsBase.sample(s::BatchSampler{nothing}, t::AbstractTraces) = StatsBase.sample(s, t, keys(t))
4949
StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = StatsBase.sample(s, t, names)
5050

5151
function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
52-
inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
52+
inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batchsize)
5353
NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names))
5454
end
5555

@@ -75,12 +75,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7575
p = collect(deepcopy(t.priorities))
7676
w = StatsBase.FrequencyWeights(p)
7777
w .*= e.sampleable_inds[1:end-1]
78-
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
78+
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
7979
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
8080
end
8181

8282
function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
83-
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
83+
inds, priorities = rand(s.rng, t.priorities, s.batchsize)
8484
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
8585
end
8686

@@ -165,41 +165,42 @@ end
165165
export NStepBatchSampler
166166

167167
"""
168-
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)
168+
169+
NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng())
169170
170171
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
171172
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172173
that in up to `n > 1` steps later in the buffer. The reward will be
173174
the discounted sum of the `n` rewards, with `γ` as the discount factor.
174175
175-
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
176-
to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
177-
of partial observability, for example when the state is approximated by `stack_size` consecutive
176+
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
177+
to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
178+
of partial observability, for example when the state is approximated by `stacksize` consecutive
178179
frames.
179180
"""
180-
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
181+
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
181182
n::Int # !!! n starts from 1
182183
γ::Float32
183-
batch_size::Int
184-
stack_size::S
185-
rng::Any
184+
batchsize::Int
185+
stacksize::S
186+
rng::R
186187
end
187188

188189
NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
189-
function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
190+
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
190191
@assert n >= 1 "n must be ≥ 1."
191-
ss = stack_size == 1 ? nothing : stack_size
192-
NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
192+
ss = stacksize == 1 ? nothing : stacksize
193+
NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
193194
end
194195

195-
#return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
196+
#return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
196197
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
197198
range = copy(eb.sampleable_inds)
198199
ns = Vector{Int}(undef, length(eb.sampleable_inds))
199-
stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
200+
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
200201
for idx in eachindex(range)
201202
step_number = eb.step_numbers[idx]
202-
range[idx] = step_number >= stack_size && eb.sampleable_inds[idx]
203+
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
203204
ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
204205
end
205206
return range, ns
@@ -211,19 +212,19 @@ end
211212

212213
function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
213214
weights, ns = valid_range(s, t)
214-
inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batch_size)
215+
inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize)
215216
fetch(s, t, Val(names), inds, ns)
216217
end
217218

218219
function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
219220
NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
220221
end
221222

222-
#state and next_state have specialized fetch methods due to stack_size
223+
#state and next_state have specialized fetch methods due to stacksize
223224
fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
224-
fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stack_size+1:0, x in inds]]
225+
fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stacksize+1:0, x in inds]]
225226
fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
226-
fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stack_size+1:0, (idx,x) in enumerate(inds)]]
227+
fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stacksize+1:0, (idx,x) in enumerate(inds)]]
227228

228229
#reward due to discounting
229230
function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
@@ -247,7 +248,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
247248
w = StatsBase.FrequencyWeights(p)
248249
valids, ns = valid_range(s,e)
249250
w .*= valids[1:end-1]
250-
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
251+
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
251252
merge(
252253
(key=t.keys[inds], priority=p[inds]),
253254
fetch(s, e, Val(names), inds, ns)
@@ -297,3 +298,74 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
297298

298299
return [make_episode(t, r, names) for r in ranges]
299300
end
301+
302+
##### MultiStepSampler

"""
    MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng())
    MultiStepSampler(t::AbstractTraces; kw...)

Sampler that fetches steps `[x, x+1, ..., x + n - 1]` for each trace of each sampled index
`x`. The samples are returned in an array of `batchsize` elements. For each element, `n` is
truncated by the end of its episode, so the samples do not all have the same dimensions.

# Fields
- `n`: maximum number of consecutive steps fetched per sampled index.
- `batchsize`: number of indices drawn per batch.
- `stacksize`: `nothing` when no stacking is requested, otherwise the number of consecutive
  frames stacked for every fetched `:state` / `:next_state` step.
- `rng`: random number generator used to draw the indices.
"""
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
    n::Int
    batchsize::Int
    # Typed with `S` rather than `Int`: the keyword constructor stores `nothing` when no
    # stacking is requested, and the stacked `fetch` methods dispatch on `S == Int`.
    stacksize::S
    rng::R
end
318+
319+
# Convenience constructor: use `keys(t)` as the `names` type parameter and forward every
# keyword argument unchanged to the keyword constructor.
function MultiStepSampler(t::AbstractTraces; kw...)
    return MultiStepSampler{keys(t)}(; kw...)
end
320+
# Keyword constructor. `n` must be at least 1. A `stacksize` of 1 is stored as `nothing`,
# i.e. it is treated the same as not requesting any stacking.
function MultiStepSampler{names}(;
    n, batchsize=32, stacksize=nothing, rng=Random.default_rng()
) where {names}
    @assert n >= 1 "n must be ≥ 1."
    stack = stacksize == 1 ? nothing : stacksize
    return MultiStepSampler{names, typeof(stack), typeof(rng)}(n, batchsize, stack, rng)
end
325+
326+
# Return `(valid, horizons)`, both indexed like `eb.sampleable_inds`:
# - `valid[i]` is true when index `i` is sampleable and its step number within the episode
#   is at least the stack size (1 when no stacking is configured);
# - `horizons[i]` is `s.n` truncated to the number of steps from `i` to the episode length
#   recorded for `i`.
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
    min_step = something(s.stacksize, 1)
    valid = copy(eb.sampleable_inds)
    horizons = Vector{Int}(undef, length(eb.sampleable_inds))
    for i in eachindex(valid)
        stepno = eb.step_numbers[i]
        valid[i] = stepno >= min_step && eb.sampleable_inds[i]
        horizons[i] = min(s.n, eb.episodes_lengths[i] - stepno + 1)
    end
    return valid, horizons
end
337+
338+
# Entry point: lift the `names` type parameter into a `Val` so that the three-argument
# `sample` methods can dispatch on it.
StatsBase.sample(s::MultiStepSampler{names}, ts) where {names} =
    StatsBase.sample(s, ts, Val(names))
341+
342+
# Draw `s.batchsize` start indices among the valid ones and fetch, for each of them, the
# multi-step slices of every trace listed in `names`.
function StatsBase.sample(s::MultiStepSampler, t::EpisodesBuffer, ::Val{names}) where {names}
    valid, horizons = valid_range(s, t)
    # The last validity flag is dropped before weighting (presumably so the weights line up
    # with `1:length(t)` — confirm against EpisodesBuffer).
    w = StatsBase.FrequencyWeights(valid[1:end-1])
    picked = StatsBase.sample(s.rng, 1:length(t), w, s.batchsize)
    return fetch(s, t, Val(names), picked, horizons)
end
347+
348+
# Build the batch as a NamedTuple with one entry per trace name. `ns` is indexed by buffer
# position, so it is narrowed to the sampled indices before reaching the per-trace methods.
function fetch(s::MultiStepSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where {names}
    columns = map(names) do name
        collect(fetch(s, ts[name], Val(name), inds, ns[inds]))
    end
    return NamedTuple{names}(columns)
end
351+
352+
# Default per-trace fetch: for the i-th sampled index, return the `ns[i]` consecutive
# entries of `trace` starting at that index. Slices may therefore differ in length.
function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
    return map(enumerate(inds)) do (i, first_step)
        last_step = first_step + ns[i] - 1
        trace[first_step:last_step]
    end
end
355+
356+
# Stacked fetch for `:state` and `:next_state` when a stack size is set. For the j-th
# sampled index an index matrix is built with one column per fetched step (`ns[j]` columns)
# and one row per stacked frame (`s.stacksize` rows, ending at the step itself).
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
    offsets = (1 - s.stacksize):0
    return map(enumerate(inds)) do (j, idx)
        frames = [idx + off + k - 1 for off in offsets, k in 1:ns[j]]
        trace[frames]
    end
end
359+
360+
# Prioritized variant: indices are drawn with weights equal to their priorities, masked by
# the validity flags from `valid_range`. The returned NamedTuple holds the keys and
# priorities of the sampled indices merged with the fetched multi-step traces.
function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
    prioritized = e.traces
    priorities = collect(deepcopy(prioritized.priorities))
    weights = StatsBase.FrequencyWeights(priorities)
    valid, horizons = valid_range(s, e)
    # Zero out the weight of every index that is not a valid start.
    weights .*= valid[1:end-1]
    picked = StatsBase.sample(s.rng, eachindex(weights), weights, s.batchsize)
    meta = (key=prioritized.keys[picked], priority=priorities[picked])
    return merge(meta, fetch(s, e, Val(names), picked, horizons))
end

test/samplers.jl

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ import ReinforcementLearningTrajectories.fetch
7878
γ = 0.99
7979
n_stack = 2
8080
n_horizon = 3
81-
batch_size = 1000
81+
batchsize = 1000
8282
eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
83-
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
83+
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stacksize=n_stack, batchsize=batchsize)
8484

8585
push!(eb, (state = 1, action = 1))
8686
for i = 1:5
@@ -98,12 +98,12 @@ import ReinforcementLearningTrajectories.fetch
9898
for key in keys(eb)
9999
@test haskey(batch, key)
100100
end
101-
#state: samples with stack_size
101+
#state: samples with stacksize
102102
states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
103103
@test states == [1 2 3 4 7 8 9;
104104
2 3 4 5 8 9 10]
105105
@test all(in(eachcol(states)), unique(eachcol(batch[:state])))
106-
#next_state: samples with stack_size and nsteps forward
106+
#next_state: samples with stacksize and nsteps forward
107107
next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
108108
@test next_states == [4 5 5 5 10 10 10;
109109
5 6 6 6 11 11 11]
@@ -127,9 +127,9 @@ import ReinforcementLearningTrajectories.fetch
127127
### CircularPrioritizedTraces and NStepBatchSampler
128128
γ = 0.99
129129
n_horizon = 3
130-
batch_size = 4
130+
batchsize = 4
131131
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
132-
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)
132+
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)
133133

134134
push!(eb, (state = 1, action = 1))
135135
for i = 1:5
@@ -196,4 +196,48 @@ import ReinforcementLearningTrajectories.fetch
196196
@test length(b[2][:state]) == 5
197197
@test !haskey(b[1], :action)
198198
end
199+
@testset "MultiStepSampler" begin
    RLT = ReinforcementLearningTrajectories
    n_stack, n_horizon, batchsize = 2, 3, 1000
    eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
    s1 = MultiStepSampler(eb, n=n_horizon, stacksize=n_stack, batchsize=batchsize)

    # First episode: states 1 to 6, terminal on its fifth transition.
    push!(eb, (state = 1, action = 1))
    for i in 1:5
        push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
    end
    # Second episode: states 7 to 11, never terminated.
    push!(eb, (state = 7, action = 7))
    for i in 8:11
        push!(eb, (state = i, action = i, reward = i-1, terminal = false))
    end

    weights, ns = RLT.valid_range(s1, eb)
    @test weights == [0,1,1,1,1,0,0,1,1,1,0]
    @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
    inds = findall(==(1), weights)
    batch = sample(s1, eb)
    for key in keys(eb)
        @test haskey(batch, key)
    end

    # Fetch one trace deterministically for every valid start index.
    fetched(name) = RLT.fetch(s1, eb[name], Val(name), inds, ns[inds])

    #state and next_state: samples with stacksize
    states = fetched(:state)
    @test states == [[1 2 3; 2 3 4], [2 3 4; 3 4 5], [3 4; 4 5], [4; 5;;], [7 8 9; 8 9 10], [8 9; 9 10], [9; 10;;]]
    @test all(in(states), batch[:state])
    #next_state: samples with stacksize and nsteps forward
    next_states = fetched(:next_state)
    @test next_states == [[2 3 4; 3 4 5], [3 4 5; 4 5 6], [4 5; 5 6], [5; 6;;], [8 9 10; 9 10 11], [9 10; 10 11], [10; 11;;]]
    @test all(in(next_states), batch[:next_state])
    #all other traces sample normally
    actions = fetched(:action)
    @test actions == [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10],[10]]
    @test all(in(actions), batch[:action])
    next_actions = fetched(:next_action)
    @test next_actions == [a .+ 1 for a in [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10],[10]]]
    @test all(in(next_actions), batch[:next_action])
    rewards = fetched(:reward)
    @test rewards == actions
    @test all(in(rewards), batch[:reward])
    terminals = fetched(:terminal)
    @test terminals == [[a == 5 ? 1 : 0 for a in acs] for acs in actions]
end
199243
end

0 commit comments

Comments
 (0)