Skip to content

Commit 39ea7a5

Browse files
drop test / function
1 parent 60b7d35 commit 39ea7a5

File tree

2 files changed

+1
-34
lines changed

2 files changed

+1
-34
lines changed

src/samplers.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
184184
s = ts[:state][inds]
185185
s′ = ts[:next_state][inds.+(nbs.n-1)]
186186
else
187-
s = ts[:state][[x + i + 1 for i in -nbs.stack_size+1:0, x in inds]]
187+
s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
188188
s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
189189
end
190190

@@ -220,12 +220,3 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
220220
StatsBase.sample(s, t.traces, Val(names), inds)
221221
)
222222
end
223-
224-
function StatsBase.sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where {names}
225-
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
226-
227-
merge(
228-
(key=t.keys[inds], priority=priorities),
229-
StatsBase.sample(s, t.traces, Val(names), inds)
230-
)
231-
end

test/samplers.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -135,30 +135,6 @@ end
135135
@test xs3.reward[1] 3 + γ * 4 # terminated at step 4
136136
@test xs3.reward[2] 5 + γ * (6 + γ * 7)
137137
@test xs3.reward[3] 7 + γ * (8 + γ * 9)
138-
139-
@testset "CircularPrioritizedTraces with NStepBatchSampler" begin
140-
γ = 0.9
141-
n_stack = 2
142-
n_horizon = 3
143-
batch_size = 4
144-
145-
t = CircularPrioritizedTraces(
146-
CircularArraySARTSATraces(;
147-
capacity=10
148-
),
149-
default_priority=1.0f0
150-
)
151-
s = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
152-
push!(t, (state = 1, action = true))
153-
for i = 1:9
154-
push!(t, (state = i+1, action = true, reward = i, terminal = false))
155-
end
156-
157-
xs = RLTrajectories.StatsBase.sample(s, t)
158-
@test haskey(xs, :state)
159-
@test haskey(xs, :priority)
160-
@test haskey(xs, :key)
161-
end
162138
end
163139
#! format: on
164140

0 commit comments

Comments
 (0)