Skip to content

Commit 94f8df7

Browse files
authored
Merge pull request #61 from JuliaReinforcementLearning/mssfix
fix ms sampler
2 parents 70f1a1f + c92b776 commit 94f8df7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/samplers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ end
302302
#####MultiStepSampler
303303

304304
"""
305-
MultiStepSampler{names}(batchsize, stacksize, n, rng)
305+
MultiStepSampler{names}(batchsize, n, stacksize, rng)
306306
307307
Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308308
`x`. The samples are returned in an array of batchsize elements. For each element, n is
@@ -312,12 +312,12 @@ the same.
312312
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
313313
n::Int
314314
batchsize::Int
315-
stacksize::Int
315+
stacksize::S
316316
rng::R
317317
end
318318

319319
MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)
320-
function MultiStepSampler{names}(; n, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
320+
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
321321
@assert n >= 1 "n must be ≥ 1."
322322
ss = stacksize == 1 ? nothing : stacksize
323323
MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)

0 commit comments

Comments
 (0)