Skip to content

Commit 7b041d3

Browse files
committed
docstring
1 parent 85de617 commit 7b041d3

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

src/samplers.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ end
163163

164164
export NStepBatchSampler
165165

166+
"""
167+
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)
168+
169+
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
170+
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
171+
the one located up to `n` steps later in the buffer when `n > 1` (or the last one of the episode). The reward will be
172+
the discounted sum of the `n` rewards, with `γ` as the discount factor.
173+
174+
`NStepBatchSampler` may also be used with `n ≥ 1` to sample a "stack" of states if `stack_size` is set
175+
to an integer > 1. In that case, the `stack_size - 1` previous states are sampled as well. This is useful in the case
176+
of partial observability, for example when the state is approximated by `stack_size` consecutive
177+
frames.
178+
"""
166179
mutable struct NStepBatchSampler{traces}
167180
n::Int # !!! n starts from 1
168181
γ::Float32
@@ -172,7 +185,10 @@ mutable struct NStepBatchSampler{traces}
172185
end
173186

174187
# Convenience constructor: when no trace names are given as a type parameter,
# forward every keyword argument unchanged to `NStepBatchSampler{SS′ART}`.
# NOTE(review): `SS′ART` is defined elsewhere; presumably the default set of
# trace names (state, next state, action, reward, terminal) — confirm.
function NStepBatchSampler(; kw...)
    return NStepBatchSampler{SS′ART}(; kw...)
end
175-
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
188+
# Keyword constructor for `NStepBatchSampler{names}`.
#
# Keywords:
# - `n`: number of steps; must be ≥ 1.
# - `γ`: discount factor.
# - `batch_size`: defaults to 32.
# - `stack_size`: defaults to `nothing`.
# - `rng`: random number generator, defaults to `Random.GLOBAL_RNG`.
#
# Throws an `ArgumentError` when `n < 1`. An explicit check is used instead of
# `@assert` because assertions may be disabled at some optimization levels and
# are therefore not suitable for validating user-supplied arguments.
function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
    n >= 1 || throw(ArgumentError("n must be ≥ 1."))
    return NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
end
176192

177193

178194
function valid_range_nbatchsampler(s::NStepBatchSampler, ts)

0 commit comments

Comments
 (0)