Skip to content

Commit 4de0dce

Browse files
Add test and fix for sampler dispatch issue
1 parent 2be2aa3 commit 4de0dce

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

src/samplers.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,21 @@ function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
209209
NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
210210
end
211211

212-
function StatsBase.sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where {names}
213-
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
212+
function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces})
213+
t = e.traces
214+
st = deepcopy(t.priorities)
215+
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
216+
inds, priorities = rand(s.rng, st, s.batch_size)
217+
218+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[x][inds]), names)...))
219+
end
220+
221+
function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
222+
t = e.traces
223+
st = deepcopy(t.priorities)
224+
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
225+
inds, priorities = rand(s.rng, st, s.batch_size)
226+
214227
merge(
215228
(key=t.keys[inds], priority=priorities),
216229
StatsBase.sample(s, t.traces, Val(names), inds)

test/samplers.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,35 @@ end
161161
end
162162
end
163163
#! format: on
164+
165+
@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
166+
n=1
167+
γ=0.99f0
168+
169+
t = Trajectory(
170+
container=CircularPrioritizedTraces(
171+
CircularArraySARTSTraces(
172+
capacity=1000,
173+
state=Float32 => (4,),
174+
);
175+
default_priority=100.0f0
176+
),
177+
sampler=NStepBatchSampler{SS′ART}(
178+
n=n,
179+
γ=γ,
180+
batch_size=32,
181+
),
182+
controller=InsertSampleRatioController(
183+
threshold=100,
184+
n_inserted=-1
185+
)
186+
)
187+
188+
push!(t, (state = 1, action = true))
189+
for i = 1:9
190+
push!(t, (state = i+1, action = true, reward = i, terminal = false))
191+
end
192+
193+
b = RLTrajectories.StatsBase.sample(t)
194+
@test haskey(b, :priority)
195+
end

0 commit comments

Comments
 (0)