Skip to content

Commit 40fefb2

Browse files
Merge pull request #49 from JuliaReinforcementLearning/jpsl/fix
Bug fix for EpisodesBuffer sampling
2 parents ccb43cd + b06c4a3 commit 40fefb2

File tree

4 files changed

+48
-4
lines changed

4 files changed

+48
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.3.1"
3+
version = "0.3.2"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/episodes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function pad!(trace::Trace)
9090
return nothing
9191
end
9292

93-
pad!(buf::CircularArrayBuffer{T}) where {T,N,A} = push!(buf, zero(T))
93+
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
9494
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
9595

9696
#push a duplicate of last element as a dummy element for all 'trace' objects, ignores multiplex traces, should never be sampled.

src/samplers.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,25 @@ end
173173
NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
174174
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
175175

176+
177+
function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
178+
# think about the extreme case where s.stack_size == 1 and s.n == 1
179+
isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
180+
end
176181
function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
177-
valid_range = isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))# think about the exteme case where s.stack_size == 1 and s.n == 1
182+
valid_range = valid_range_nbatchsampler(s, ts)
178183
inds = rand(s.rng, valid_range, s.batch_size)
179184
StatsBase.sample(s, ts, Val(names), inds)
180185
end
181186

187+
function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
188+
valid_range = valid_range_nbatchsampler(s, ts)
189+
valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
190+
inds = rand(s.rng, valid_range, s.batch_size)
191+
StatsBase.sample(s, ts, Val(names), inds)
192+
end
193+
194+
182195
function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
183196
if isnothing(nbs.stack_size)
184197
s = ts[:state][inds]

test/samplers.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ end
138138
end
139139
#! format: on
140140

141-
@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
141+
@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
142142
n=1
143143
γ=0.99f0
144144

@@ -168,4 +168,35 @@ end
168168

169169
b = RLTrajectories.StatsBase.sample(t)
170170
@test haskey(b, :priority)
171+
@test sum(b.action .== 0) == 0
172+
end
173+
174+
175+
@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
176+
n=1
177+
γ=0.99f0
178+
179+
t = Trajectory(
180+
container=CircularArraySARTSTraces(
181+
capacity=5,
182+
state=Float32 => (4,),
183+
),
184+
sampler=NStepBatchSampler{SS′ART}(
185+
n=n,
186+
γ=γ,
187+
batch_size=32,
188+
),
189+
controller=InsertSampleRatioController(
190+
threshold=100,
191+
n_inserted=-1
192+
)
193+
)
194+
195+
push!(t, (state = 1, action = true))
196+
for i = 1:9
197+
push!(t, (state = i+1, action = true, reward = i, terminal = false))
198+
end
199+
200+
b = RLTrajectories.StatsBase.sample(t)
201+
@test sum(b.action .== 0) == 0
171202
end

0 commit comments

Comments
 (0)