Skip to content

Commit c9f2ebe

Browse files
Merge pull request #44 from JuliaReinforcementLearning/jpsl/traj-fixes
Fix interaction between CircularPrioritizedTraces and NStepBatchSampler
2 parents e3179da + 9ee7cda commit c9f2ebe

File tree

5 files changed

+57
-6
lines changed

5 files changed

+57
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.2"
3+
version = "0.2.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/common/CircularPrioritizedTraces.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,6 @@ function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
5959
end
6060
end
6161

62-
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
62+
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
63+
64+
capacity(t::CircularPrioritizedTraces) = ReinforcementLearningTrajectories.capacity(t.traces)

src/episodes.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,21 @@ struct PartialNamedTuple{T}
3939
namedtuple::T
4040
end
4141

42+
# Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 for certain cases
43+
function is_capacity_plus_one(traces::AbstractTraces)
44+
if any(t->t isa MultiplexTraces, traces.traces)
45+
# MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity
46+
return true
47+
elseif traces isa CircularPrioritizedTraces
48+
# CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity
49+
return true
50+
else
51+
false
52+
end
53+
end
54+
4255
function EpisodesBuffer(traces::AbstractTraces)
43-
cap = any(t->t isa MultiplexTraces, traces.traces) ? capacity(traces) + 1 : capacity(traces)
56+
cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces)
4457
@assert isempty(traces) "EpisodesBuffer must be initialized with empty traces."
4558
if !isinf(cap)
4659
legalinds = CircularBuffer{Bool}(cap)

src/samplers.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,12 @@ function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
209209
NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
210210
end
211211

212-
function StatsBase.sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where {names}
213-
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
212+
function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
213+
t = e.traces
214+
st = deepcopy(t.priorities)
215+
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
216+
inds, priorities = rand(s.rng, st, s.batch_size)
217+
214218
merge(
215219
(key=t.keys[inds], priority=priorities),
216220
StatsBase.sample(s, t.traces, Val(names), inds)

test/samplers.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,36 @@ end
136136
@test xs3.reward[2] 5 + γ * (6 + γ * 7)
137137
@test xs3.reward[3] 7 + γ * (8 + γ * 9)
138138
end
139-
#! format: on
139+
#! format: on
140+
141+
@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
142+
n=1
143+
γ=0.99f0
144+
145+
t = Trajectory(
146+
container=CircularPrioritizedTraces(
147+
CircularArraySARTSTraces(
148+
capacity=5,
149+
state=Float32 => (4,),
150+
);
151+
default_priority=100.0f0
152+
),
153+
sampler=NStepBatchSampler{SS′ART}(
154+
n=n,
155+
γ=γ,
156+
batch_size=32,
157+
),
158+
controller=InsertSampleRatioController(
159+
threshold=100,
160+
n_inserted=-1
161+
)
162+
)
163+
164+
push!(t, (state = 1, action = true))
165+
for i = 1:9
166+
push!(t, (state = i+1, action = true, reward = i, terminal = false))
167+
end
168+
169+
b = RLTrajectories.StatsBase.sample(t)
170+
@test haskey(b, :priority)
171+
end

0 commit comments

Comments
 (0)