Skip to content

Commit 323793d

Browse files
authored
Merge pull request #40 from JuliaReinforcementLearning/episodecontroller
add Episode Controller
2 parents cb74a59 + 8f4248f commit 323793d

File tree

4 files changed

+79
-12
lines changed

4 files changed

+79
-12
lines changed

src/controllers.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export InsertSampleRatioController, AsyncInsertSampleRatioController
1+
export InsertSampleRatioController, AsyncInsertSampleRatioController, EpisodeSampleRatioController
22

33
"""
44
InsertSampleRatioController(;ratio=1., threshold=1)
@@ -15,18 +15,16 @@ Base.@kwdef mutable struct InsertSampleRatioController
1515
n_sampled::Int = 0
1616
end
1717

18-
function on_insert!(c::InsertSampleRatioController, n::Int)
18+
function on_insert!(c::InsertSampleRatioController, n::Int, ::Any)
1919
if n > 0
2020
c.n_inserted += n
2121
end
2222
end
2323

2424
function on_sample!(c::InsertSampleRatioController)
25-
if c.n_inserted >= c.threshold
26-
if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
27-
c.n_sampled += 1
28-
return true
29-
end
25+
if c.n_inserted >= c.threshold && c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
26+
c.n_sampled += 1
27+
return true
3028
end
3129
return false
3230
end
@@ -59,3 +57,32 @@ function AsyncInsertSampleRatioController(
5957
Channel(ch_out_sz)
6058
)
6159
end
60+
61+
"""
62+
EpisodeSampleRatioController(;ratio=1., threshold=1)
63+
64+
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
65+
episodes completed before sampling is allowed. The `ratio` balances the number of episodes and
66+
the number of samplings. For example a ratio of 1/10 will sample once every 10
67+
episodes in the trajectory. Currently only works for environemnts with terminal states.
68+
"""
69+
Base.@kwdef mutable struct EpisodeSampleRatioController
70+
ratio::Float64 = 1.0
71+
threshold::Int = 1
72+
n_episodes::Int = 0
73+
n_sampled::Int = 0
74+
end
75+
76+
function on_insert!(c::EpisodeSampleRatioController, n::Int, x::NamedTuple)
77+
if n > 0
78+
c.n_episodes += sum(x.terminal)
79+
end
80+
end
81+
82+
function on_sample!(c::EpisodeSampleRatioController)
83+
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
84+
c.n_sampled += 1
85+
return true
86+
end
87+
return false
88+
end

src/trajectory.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,18 @@ Base.setindex!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, v,
8787

8888
function Base.append!(t::Trajectory, x)
8989
append!(t.container, x)
90-
on_insert!(t.controller, length(x))
90+
on_insert!(t.controller, length(x), x)
9191
end
9292

9393
# !!! by default we assume `x` is a complete example which contains all the traces
9494
# When doing partial inserting, the result of undefined
9595
function Base.push!(t::Trajectory, x)
9696
push!(t.container, x)
97-
on_insert!(t)
97+
on_insert!(t, x)
9898
end
9999

100-
on_insert!(t::Trajectory) = on_insert!(t, 1)
101-
on_insert!(t::Trajectory, n::Int) = on_insert!(t.controller, n)
100+
on_insert!(t::Trajectory, x) = on_insert!(t, 1, x)
101+
on_insert!(t::Trajectory, n::Int, x) = on_insert!(t.controller, n, x)
102102

103103
#####
104104
# out

test/controllers.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import ReinforcementLearningTrajectories: on_insert!, on_sample!
2+
@testset "controllers.jl" begin
3+
@testset "EpisodeSampleRatioController" begin
4+
#push
5+
c = EpisodeSampleRatioController(ratio = 1/2, threshold = 5)
6+
for st in 1:50
7+
transition = (state = 1, action = 2, reward = 5., terminal = (st % 5 == 0))
8+
on_insert!(c, 1, transition)
9+
if st in 25:10:45
10+
@test on_sample!(c)
11+
@test !on_sample!(c)
12+
else
13+
@test !on_sample!(c)
14+
end
15+
end
16+
#append
17+
c = EpisodeSampleRatioController(ratio = 1/2, threshold = 5)
18+
for e in 1:20
19+
transitions = (state = ones(5), action = ones(5), reward = ones(5), terminal = [false, false, false, false, iseven(e)])
20+
on_insert!(c, length(first(transitions)), transitions)
21+
if e in 10:4:20
22+
@test on_sample!(c)
23+
@test !on_sample!(c)
24+
else
25+
@test !on_sample!(c)
26+
end
27+
end
28+
c = EpisodeSampleRatioController(ratio = 1/4, threshold = 5)
29+
for e in 1:10
30+
transitions = (state = ones(10), action = ones(10), reward = ones(10), terminal = [false, false, false, false, true, false, false, false, false, true])
31+
on_insert!(c, length(first(transitions)), transitions)
32+
if e in 3:2:10
33+
@test on_sample!(c)
34+
@test !on_sample!(c)
35+
else
36+
@test !on_sample!(c)
37+
end
38+
end
39+
end
40+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Adapt.adapt_storage(to::TestAdaptor, x) = CUDA.functional() ? CUDA.cu(x) : x
1414
include("traces.jl")
1515
include("common.jl")
1616
include("samplers.jl")
17+
include("controllers.jl")
1718
include("trajectories.jl")
1819
include("normalization.jl")
19-
include("samplers.jl")
2020
end

0 commit comments

Comments
 (0)