@@ -168,86 +168,77 @@ export NStepBatchSampler
168
168
169
169
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
170
170
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
171
- that in up to `n > 1` steps later in the buffer (or the last of the episode) . The reward will be
171
+ that in up to `n > 1` steps later in the buffer. The reward will be
172
172
the discounted sum of the `n` rewards, with `γ` as the discount factor.
173
173
174
174
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
175
175
to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
176
176
of partial observability, for example when the state is approximated by `stack_size` consecutive
177
177
frames.
178
178
"""
179
- mutable struct NStepBatchSampler{traces }
179
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
    n::Int # !!! n starts from 1
    γ::Float32 # discount factor applied to the n-step reward sum
    batch_size::Int # number of indices drawn per call to `sample`
    stack_size::S # `nothing` when states are not stacked, otherwise the stack length
    rng::Any # random number generator handed to the sampling calls
end
186
186
187
187
# Keyword-only convenience constructor: uses `SS′ART` as the trace names.
function NStepBatchSampler(; kw...)
    return NStepBatchSampler{SS′ART}(; kw...)
end
188
188
# Keyword constructor for a sampler that fetches the traces listed in `names`.
#
# Keywords:
# - `n`: number of steps of the n-step return, must be ≥ 1.
# - `γ`: discount factor.
# - `batch_size`: number of samples per batch (default 32).
# - `stack_size`: number of stacked states, or `nothing` for no stacking.
# - `rng`: random number generator used for sampling.
function NStepBatchSampler{names}(;
    n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG
) where {names}
    @assert n >= 1 "n must be ≥ 1."
    # A stack of a single state is the same as no stacking at all.
    ss = stack_size == 1 ? nothing : stack_size
    # Both type parameters must be given: no default constructor exists for the
    # partially parameterized `NStepBatchSampler{names}`.
    return NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
end
192
192
193
-
194
- function valid_range_nbatchsampler (s:: NStepBatchSampler , ts)
195
- # think about the extreme case where s.stack_size == 1 and s.n == 1
196
- isnothing (s. stack_size) ? (1 : (length (ts)- s. n+ 1 )) : (s. stack_size: (length (ts)- s. n+ 1 ))
193
# Build a Bool mask over the buffer indices from which an n-step sample may start.
# An index is kept when it is flagged in `eb.sampleable_inds`, its step number is at
# least the stack size, and `s.n` consecutive steps still fit before the end of its
# episode.
function valid_range_nbatchsampler(s::NStepBatchSampler, eb::EpisodesBuffer)
    valid = copy(eb.sampleable_inds)
    stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
    for idx in eachindex(valid)
        step = eb.step_numbers[idx]
        # latest step of the episode from which `s.n` steps can still be read
        last_start = eb.episodes_lengths[idx] + 1 - s.n
        valid[idx] = valid[idx] && stack_size <= step <= last_start
    end
    return valid
end
202
+
198
203
# Entry point: forward the trace names stored in the sampler's type parameter to the
# three-argument method. The names are passed as a plain tuple because that method
# iterates over them and uses them as `NamedTuple` keys.
function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
    return StatsBase.sample(s, ts, names)
end
203
206
204
- function StatsBase. sample (s:: NStepBatchSampler{names} , ts:: EpisodesBuffer ) where {names}
205
- valid_range = valid_range_nbatchsampler (s, ts)
206
- valid_range = valid_range[valid_range .∈ (findall (ts. sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
207
- inds = rand (s. rng, valid_range, s. batch_size)
208
- StatsBase. sample (s, ts, Val (names), inds)
207
# Sample from an `EpisodesBuffer`: only indices accepted by
# `valid_range_nbatchsampler` receive a non-zero sampling weight.
function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, names)
    valid = valid_range_nbatchsampler(s, t)
    # Weighted sampling needs exactly one weight per candidate index
    # `1:length(t.traces)`.
    # NOTE(review): the mask is copied from `sampleable_inds`, which presumably
    # holds one more entry than the traces — confirm against `EpisodesBuffer`.
    weights = StatsBase.FrequencyWeights(valid[1:length(t.traces)])
    return StatsBase.sample(s, t.traces, names, weights)
end
210
211
211
-
212
- function StatsBase. sample (nbs:: NStepBatchSampler , ts, :: Val{SS′ART} , inds)
213
- if isnothing (nbs. stack_size)
214
- s = ts[:state ][inds]
215
- s′ = ts[:next_state ][inds.+ (nbs. n- 1 )]
216
- else
217
- s = ts[:state ][[x + i for i in - nbs. stack_size+ 1 : 0 , x in inds]]
218
- s′ = ts[:next_state ][[x + nbs. n - 1 + i for i in - nbs. stack_size+ 1 : 0 , x in inds]]
219
- end
220
-
221
- a = ts[:action ][inds]
222
- t_horizon = ts[:terminal ][[x + j for j in 0 : nbs. n- 1 , x in inds]]
223
- r_horizon = ts[:reward ][[x + j for j in 0 : nbs. n- 1 , x in inds]]
224
-
225
- @assert ndims (t_horizon) == 2
226
- t = any (t_horizon, dims= 1 ) |> vec
227
-
228
- @assert ndims (r_horizon) == 2
229
- r = map (eachcol (r_horizon), eachcol (t_horizon)) do r⃗, t⃗
230
- foldr (((rr, tt), init) -> rr + nbs. γ * init * (1 - tt), zip (r⃗, t⃗); init= 0.0f0 )
231
- end
232
-
233
- NamedTuple {SS′ART} (map (collect, (s, s′, a, r, t)))
212
# Draw `s.batch_size` indices from `1:length(t)` according to `weights` (uniform by
# default) and fetch every trace listed in `names` at those indices.
# Returns a `NamedTuple` keyed by `names`.
function StatsBase.sample(
    s::NStepBatchSampler,
    t::AbstractTraces,
    names,
    weights=StatsBase.UnitWeights{Int}(length(t)),
)
    inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
    # each trace is fetched by the method specialized on its name
    batch = map(name -> collect(fetch(s, t[name], Val(name), inds)), names)
    return NamedTuple{names}(batch)
end
235
216
236
- function StatsBase. sample (s:: NStepBatchSampler , ts, :: Val{SS′L′ART} , inds)
237
- s, s′, a, r, t = StatsBase. sample (s, ts, Val (SSART), inds)
238
- l = consecutive_view (ts[:next_legal_actions_mask ], inds)
239
- NamedTuple {SSLART} (map (collect, (s, s′, l, a, r, t)))
217
# state and next_state have specialized fetch methods due to stack_size

# No stacking: the state is read at the sampled index itself.
function fetch(::NStepBatchSampler{names, Nothing}, trace, ::Val{:state}, inds) where {names}
    return trace[inds]
end

# Stacking: column `k` indexes the `stack_size` consecutive states ending at `inds[k]`.
function fetch(s::NStepBatchSampler{names, Int}, trace, ::Val{:state}, inds) where {names}
    return trace[[x + i for i in (-s.stack_size + 1):0, x in inds]]
end

# No stacking: the next state is read `n - 1` entries after the sampled index.
function fetch(
    s::NStepBatchSampler{names, Nothing}, trace, ::Val{:next_state}, inds
) where {names}
    return trace[inds .+ (s.n - 1)]
end

# Stacking: same window as for the state, shifted by `n - 1` entries.
function fetch(s::NStepBatchSampler{names, Int}, trace, ::Val{:next_state}, inds) where {names}
    return trace[[x + s.n - 1 + i for i in (-s.stack_size + 1):0, x in inds]]
end
222
# reward due to discounting
# Returns, for each sampled index `x`, the discounted sum
# `r[x] + γ * r[x+1] + … + γ^(n-1) * r[x+n-1]`.
function fetch(s::NStepBatchSampler{names}, trace, ::Val{:reward}, inds) where {names}
    # one column per sampled index, holding its `n` consecutive rewards
    rewards = trace[[x + j for j in 0:(s.n - 1), x in inds]]
    # fold from the right so that later rewards accumulate more discount factors
    return map(eachcol(rewards)) do column
        foldr((r, acc) -> r + s.γ * acc, column; init=zero(eltype(rewards)))
    end
end
227
# terminal is that of the nth step; counting the sampled index as step 1, the nth
# step sits `n - 1` entries later.
function fetch(s::NStepBatchSampler{names}, trace, ::Val{:terminal}, inds) where {names}
    return trace[inds .+ (s.n - 1)]
end
229
# Fallback for trace names without a dedicated `fetch` method. The trace-type
# distinction is made in `_nstep_offset` rather than in the signature of `fetch`, so
# that this method stays strictly less specific than the name-specialized ones.
function fetch(s::NStepBatchSampler{names}, trace, ::Val, inds) where {names}
    return trace[inds .+ _nstep_offset(s, trace)]
end

# right multiplex traces must be n-step sampled
_nstep_offset(s::NStepBatchSampler, ::RelativeTrace{1,0}) = s.n - 1
# normal trace types are fetched at inds
_nstep_offset(::NStepBatchSampler, trace) = 0
241
233
242
234
# Prioritized sampling: indices are drawn from the priorities of the
# `CircularPrioritizedTraces`, with the priority of every index rejected by
# `valid_range_nbatchsampler` set to zero. Returns the sampled traces merged with the
# `key` and `priority` of each sample.
function StatsBase.sample(
    s::NStepBatchSampler{names},
    e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces},
) where {names}
    t = e.traces
    st = deepcopy(t.priorities)
    # temporary sumtree that puts 0 priority to non sampleable indices.
    # NOTE(review): assumes the mask is one entry longer than `t.priorities`; confirm.
    st .*= valid_range_nbatchsampler(s, e)[1:end-1]
    inds, priorities = rand(s.rng, st, s.batch_size)
    # each trace is fetched by the method specialized on its name
    batch = map(name -> collect(fetch(s, t.traces[name], Val(name), inds)), names)
    return merge((key=t.keys[inds], priority=priorities), NamedTuple{names}(batch))
end
253
244
0 commit comments