Skip to content

Commit 87fa7df

Browse files
committed
refactor
1 parent 7b041d3 commit 87fa7df

File tree

1 file changed

+37
-46
lines changed

1 file changed

+37
-46
lines changed

src/samplers.jl

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -168,86 +168,77 @@ export NStepBatchSampler
168168
169169
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
170170
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
171-
that in up to `n > 1` steps later in the buffer (or the last of the episode). The reward will be
171+
that in up to `n > 1` steps later in the buffer. The reward will be
172172
the discounted sum of the `n` rewards, with `γ` as the discount factor.
173173
174174
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
175175
to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
176176
of partial observability, for example when the state is approximated by `stack_size` consecutive
177177
frames.
178178
"""
179-
mutable struct NStepBatchSampler{traces}
179+
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
    n::Int # !!! n starts from 1
    γ::Float32 # discount factor applied to the n-step reward sum
    batch_size::Int # number of indices drawn per call to `sample`
    stack_size::S # `nothing` when no state stacking is requested, otherwise an Int > 1
    rng::Any
end

# Default to the SS′ART trace names when none are given.
NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)

"""
    NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)

Build a sampler for the traces called `names`.

A `stack_size` of `1` is stored as `nothing`, since a stack of one state is the same as
no stacking. Any other integer is converted to `Int`. The second type parameter is
derived from the normalised `stack_size`, so dispatch can distinguish stacked
(`Int`) from unstacked (`Nothing`) samplers.

Throws an `AssertionError` when `n < 1`.
"""
function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
    @assert n >= 1 "n must be ≥ 1."
    stack = (isnothing(stack_size) || stack_size == 1) ? nothing : Int(stack_size)
    # The struct has two type parameters and `names` cannot be inferred from the fields,
    # so the full parametrisation must be spelled out to reach the default constructor.
    return NStepBatchSampler{names, typeof(stack)}(n, γ, batch_size, stack, rng)
end
192192

193-
194-
function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
195-
# think about the extreme case where s.stack_size == 1 and s.n == 1
196-
isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
193+
"""
    valid_range_nbatchsampler(s::NStepBatchSampler, eb::EpisodesBuffer)

Return a copy of `eb.sampleable_inds` in which an index stays `true` only if its step
number within its episode is at least the stack size (1 when `s.stack_size` is `nothing`)
and at most `episodes_lengths[idx] + 1 - s.n`.

Presumably these bounds keep both the stacked frames and the `n`-step window inside a
single episode — confirm against how `EpisodesBuffer` fills `step_numbers`.
"""
function valid_range_nbatchsampler(s::NStepBatchSampler, eb::EpisodesBuffer)
    range = copy(eb.sampleable_inds)
    stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
    for idx in eachindex(range)
        step = eb.step_numbers[idx]
        # `n` belongs to the sampler, not to the buffer.
        last_start = eb.episodes_lengths[idx] + 1 - s.n
        # `range[idx]` still holds the original sampleable flag at this point.
        range[idx] = range[idx] && stack_size <= step <= last_start
    end
    return range
end
202+
198203
# Entry point: lift the sampler's trace names into a `Val` and forward to the
# three-argument `sample` method that matches the type of `ts`.
StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names} =
    StatsBase.sample(s, ts, Val(names))
203206

204-
function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
205-
valid_range = valid_range_nbatchsampler(s, ts)
206-
valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
207-
inds = rand(s.rng, valid_range, s.batch_size)
208-
StatsBase.sample(s, ts, Val(names), inds)
207+
# Sample from an `EpisodesBuffer`: compute which indices are valid for this sampler and
# use that mask as sampling weights over the wrapped traces.
function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, names)
    valid_range = valid_range_nbatchsampler(s, t)
    # The weights must have exactly one entry per sampleable trace index. The mask is
    # derived from `sampleable_inds`, which appears to carry one extra trailing entry
    # (other code in this file trims it with `[1:end-1]`), so cut it to the trace length.
    weights = StatsBase.FrequencyWeights(valid_range[1:length(t.traces)])
    return StatsBase.sample(s, t.traces, names, weights)
end
210211

211-
212-
function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
213-
if isnothing(nbs.stack_size)
214-
s = ts[:state][inds]
215-
s′ = ts[:next_state][inds.+(nbs.n-1)]
216-
else
217-
s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
218-
s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
219-
end
220-
221-
a = ts[:action][inds]
222-
t_horizon = ts[:terminal][[x + j for j in 0:nbs.n-1, x in inds]]
223-
r_horizon = ts[:reward][[x + j for j in 0:nbs.n-1, x in inds]]
224-
225-
@assert ndims(t_horizon) == 2
226-
t = any(t_horizon, dims=1) |> vec
227-
228-
@assert ndims(r_horizon) == 2
229-
r = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
230-
foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
231-
end
232-
233-
NamedTuple{SS′ART}(map(collect, (s, s′, a, r, t)))
212+
# Trace names may arrive wrapped in a `Val` (as forwarded by the two-argument `sample`)
# or as a plain tuple of symbols; return the tuple in both cases.
_trace_names(::Val{names}) where {names} = names
_trace_names(names) = names

# Draw `s.batch_size` indices from `1:length(t)` according to `weights` (uniform by
# default) and return a NamedTuple holding, for every requested trace name, the
# collected result of the matching `fetch` method at those indices.
function StatsBase.sample(s::NStepBatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
    trace_names = _trace_names(names)
    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
    batch = map(name -> collect(fetch(s, t[name], Val(name), inds)), trace_names)
    return NamedTuple{trace_names}(batch)
end
235216

236-
function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
237-
s, s′, a, r, t = StatsBase.sample(s, ts, Val(SSART), inds)
238-
l = consecutive_view(ts[:next_legal_actions_mask], inds)
239-
NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
217+
#state and next_state have specialized fetch methods due to stack_size
218+
# Without stacking, the state is read at the sampled indices.
fetch(::NStepBatchSampler{names, Nothing}, trace, ::Val{:state}, inds) where {names} = trace[inds]
# With stacking, each sampled index x expands to the stack_size indices ending at x.
fetch(s::NStepBatchSampler{names, Int}, trace, ::Val{:state}, inds) where {names} =
    trace[[x + i for i in -s.stack_size+1:0, x in inds]]
# The next state is the one reached after n steps, i.e. n - 1 entries further.
fetch(s::NStepBatchSampler{names, Nothing}, trace, ::Val{:next_state}, inds) where {names} =
    trace[inds .+ (s.n - 1)]
# Stacked next state: the stack_size indices ending at x + n - 1.
fetch(s::NStepBatchSampler{names, Int}, trace, ::Val{:next_state}, inds) where {names} =
    trace[[x + s.n - 1 + i for i in -s.stack_size+1:0, x in inds]]

# Reward: discounted sum r_0 + γ r_1 + … + γ^(n-1) r_(n-1) of the n rewards starting at
# each sampled index. Returns one value per sampled index.
function fetch(s::NStepBatchSampler{names}, trace, ::Val{:reward}, inds) where {names}
    # One column of n consecutive rewards per sampled index.
    rewards = trace[[x + j for j in 0:s.n-1, x in inds]]
    return map(eachcol(rewards)) do column
        # A right fold applies one extra factor of γ per step into the future.
        foldr((r, acc) -> r + s.γ * acc, column; init=zero(eltype(rewards)))
    end
end

# Terminal is that of the n-th step. Since n counts from 1, that step sits n - 1
# entries after the sampled index.
fetch(s::NStepBatchSampler{names}, trace, ::Val{:terminal}, inds) where {names} =
    trace[inds .+ (s.n - 1)]

# Fallback for every other trace name. Dispatching on the trace type happens in a
# helper so that this method stays strictly less specific than the name-specific
# methods above and no ambiguity arises. `::Val` is used because a symbol-valued
# parameter such as `Val{:action}` is not matched by `Val{<:Symbol}`.
fetch(s::NStepBatchSampler, trace, ::Val, inds) = _fetch_at(s, trace, inds)

# Right multiplex traces must be n-step sampled.
_fetch_at(s::NStepBatchSampler, trace::RelativeTrace{1,0}, inds) = trace[inds .+ (s.n - 1)]
# Other types of trace are sampled normally, at the sampled indices.
_fetch_at(::NStepBatchSampler, trace, inds) = trace[inds]
241233

242234
# Prioritized sampling: draw indices from a temporary copy of the priorities in which
# every index that is not valid for this sampler has priority 0, then fetch each
# requested trace at those indices and attach the sampled keys and priorities.
function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
    t = e.traces
    # Work on a copy so the stored priorities are not modified.
    st = deepcopy(t.priorities)
    # The validity mask has the length of `sampleable_inds`; drop its trailing entry
    # before multiplying it into the priorities.
    st .*= valid_range_nbatchsampler(s, e)[1:end-1]
    inds, priorities = rand(s.rng, st, s.batch_size)
    # `fetch` works on one trace at a time, so gather each named trace separately.
    batch = NamedTuple{names}(map(name -> collect(fetch(s, t.traces[name], Val(name), inds)), names))
    return merge((key=t.keys[inds], priority=priorities), batch)
end
253244

0 commit comments

Comments
 (0)