Skip to content

Commit 85de617

Browse files
authored
Merge pull request #53 from JuliaReinforcementLearning/episodesampler
fix invalid sampling with EpisodesSampler
2 parents dc1c3ad + 147fb6d commit 85de617

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

src/samplers.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,18 +258,23 @@ end
258258
StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer) = StatsBase.sample(s,t,keys(t))
259259
StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where names = StatsBase.sample(s,t,names)
260260

261+
function make_episode(t::EpisodesBuffer, range, names)
262+
nt = NamedTuple{names}(map(x -> collect(t[Val(x)][range]), names))
263+
Episode(nt)
264+
end
265+
261266
function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
262267
ranges = UnitRange{Int}[]
263268
idx = 1
264269
while idx < length(t)
265270
if t.sampleable_inds[idx] == 1
266-
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx] + 1
271+
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx]
267272
push!(ranges,idx:last_state_idx)
268273
idx = last_state_idx + 1
269274
else
270275
idx += 1
271276
end
272277
end
273278

274-
return [Episode(NamedTuple{names}(map(x -> collect(t[Val(x)][r]), names))) for r in ranges]
279+
return [make_episode(t, r, names) for r in ranges]
275280
end

test/samplers.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
@test size(b.action) == (sz,)
1515

1616
#In EpisodesBuffer
17-
eb = EpisodesBuffer(CircularArraySARTSTraces(capacity=10))
17+
eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
1818
push!(eb, (state = 1, action = 1))
1919
for i = 1:5
2020
push!(eb, (state = i+1, action =i+1, reward = i, terminal = false))
@@ -205,19 +205,25 @@
205205
@testset "EpisodesSampler" begin
206206
s = EpisodesSampler()
207207
eb = EpisodesBuffer(CircularArraySARTSTraces(capacity=10))
208-
push!(eb, (state = 1, action = 1))
208+
push!(eb, (state = 1,))
209209
for i = 1:5
210-
push!(eb, (state = i+1, action =i+1, reward = i, terminal = false))
210+
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
211211
end
212-
push!(eb, (state = 7, action = 7))
212+
push!(eb, (state = 7,))
213213
for (j,i) = enumerate(8:12)
214-
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
214+
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
215215
end
216216

217217
b = sample(s, eb)
218218
@test length(b) == 2
219-
@test length(b[1][:state]) == 5
220-
@test length(b[2][:state]) == 6
219+
@test b[1][:state] == [2:5;]
220+
@test b[1][:next_state] == [3:6;]
221+
@test b[1][:action] == [2:5;]
222+
@test b[1][:reward] == [2:5;]
223+
@test b[2][:state] == [7:11;]
224+
@test b[2][:next_state] == [8:12;]
225+
@test b[2][:action] == [7:11;]
226+
@test b[2][:reward] == [7:11;]
221227

222228
for (j,i) = enumerate(2:5)
223229
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
@@ -241,8 +247,8 @@
241247

242248
b = sample(s, eb)
243249
@test length(b) == 2
244-
@test length(b[1][:state]) == 5
245-
@test length(b[2][:state]) == 6
250+
@test length(b[1][:state]) == 4
251+
@test length(b[2][:state]) == 5
246252
@test !haskey(b[1], :action)
247253
end
248254
end

0 commit comments

Comments
 (0)