Skip to content

Commit 2be2aa3

Browse files
Fix interaction between CircularPrioritizedTraces and NStepBatchSampler
1 parent e3179da commit 2be2aa3

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
184184
s = ts[:state][inds]
185185
s′ = ts[:next_state][inds.+(nbs.n-1)]
186186
else
187-
s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
187+
s = ts[:state][[x + i + 1 for i in -nbs.stack_size+1:0, x in inds]]
188188
s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
189189
end
190190

test/samplers.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,29 @@ end
135135
@test xs3.reward[1] 3 + γ * 4 # terminated at step 4
136136
@test xs3.reward[2] 5 + γ * (6 + γ * 7)
137137
@test xs3.reward[3] 7 + γ * (8 + γ * 9)
138+
139+
@testset "CircularPrioritizedTraces with NStepBatchSampler" begin
140+
γ = 0.9
141+
n_stack = 2
142+
n_horizon = 3
143+
batch_size = 4
144+
145+
t = CircularPrioritizedTraces(
146+
CircularArraySARTSATraces(;
147+
capacity=10
148+
),
149+
default_priority=1.0f0
150+
)
151+
s = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
152+
push!(t, (state = 1, action = true))
153+
for i = 1:9
154+
push!(t, (state = i+1, action = true, reward = i, terminal = false))
155+
end
156+
157+
xs = RLTrajectories.StatsBase.sample(s, t)
158+
@test haskey(xs, :state)
159+
@test haskey(xs, :priority)
160+
@test haskey(xs, :key)
161+
end
138162
end
139-
#! format: on
163+
#! format: on

0 commit comments

Comments
 (0)