
Commit 208cfb7

Refactor TRPO and VPG with EpisodesSampler (#952)
* Change qpolicy default update stage
* Add a docstring
* qbasedpolicy dispatches on learner
* default to nothing
* update docs
* bump versions to require RLTraj 0.3.3
* use EpisodesSampler in experiments
* Require latest Zoo version
* And bump Exp version
* refactor VPG
* rebump compats and versions
* refactor TRPO
* add cuDNN
* use correct traj
* activate tests...
* include algos
* use stack for dimensions agnosticity
* use stack with trpo
* move slow runtests
* fix dimensions
* comment back the algos
* lower NFQ batchsize
* Update src/ReinforcementLearningZoo/Project.toml (Co-authored-by: Jeremiah <[email protected]>)
* Update src/ReinforcementLearningCore/Project.toml (Co-authored-by: Jeremiah <[email protected]>)
* Update src/ReinforcementLearningExperiments/Project.toml (Co-authored-by: Jeremiah <[email protected]>)

---------

Co-authored-by: Jeremiah <[email protected]>
1 parent f258a84 commit 208cfb7

10 files changed: +49, −29 lines

src/ReinforcementLearningCore/Project.toml

Lines changed: 2 additions & 1 deletion

@@ -25,6 +25,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 AbstractTrees = "0.3, 0.4"
@@ -42,7 +43,7 @@ Parsers = "2"
 ProgressMeter = "1"
 Reexport = "1"
 ReinforcementLearningBase = "0.12"
-ReinforcementLearningTrajectories = "^0.3.2"
+ReinforcementLearningTrajectories = "^0.3.3"
 StatsBase = "0.32, 0.33, 0.34"
 TimerOutputs = "0.5"
 UnicodePlots = "1.3, 2, 3"

src/ReinforcementLearningExperiments/Project.toml

Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@ ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 ReinforcementLearningZoo = "d607f57d-ee1e-4ba7-bcf2-7734c1e31854"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 Distributions = "0.25"
@@ -22,7 +23,7 @@ Reexport = "1"
 ReinforcementLearningBase = "0.12"
 ReinforcementLearningCore = "0.12, 0.13"
 ReinforcementLearningEnvironments = "0.8"
-ReinforcementLearningZoo = "0.7, 0.8"
+ReinforcementLearningZoo = "^0.8.3"
 StableRNGs = "1"
 Weave = "0.10"
 julia = "1.9"

src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_NFQ_CartPole.jl

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ function RLCore.Experiment(
             action=Float32 => (na,),
         ),
         sampler=BatchSampler{SS′ART}(
-            batch_size=10_000,
+            batch_size=128,
             rng=rng
         ),
         controller=InsertSampleRatioController(

src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_TRPO_CartPole.jl

Lines changed: 2 additions & 1 deletion

@@ -45,7 +45,8 @@ function RLCore.Experiment(
            ),
            rng=rng,
        ),
-        trajectory=Trajectory(container=Episode(ElasticArraySARTTraces(state=Float32 => (ns,))))
+        trajectory=Trajectory(container=CircularArraySARTSTraces(capacity = 10000, state=Float32 => (ns,)), sampler = EpisodesSampler(), controller = InsertSampleRatioController(ratio = 1/10000))
+        #Note: an EpisodeSamplerRatioController would be more adapted here.
     )
     stop_condition = StopAfterEpisode(100, is_show_progress=!haskey(ENV, "CI"))
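For context, here is how the new trajectory could be built in isolation. This is only a sketch: it assumes the constructors are used exactly as in the diff above, that the names are available via `using ReinforcementLearningTrajectories` (in the experiments they come in through the RL.jl packages), and that `ns = 4` (CartPole's observation size).

```julia
using ReinforcementLearningTrajectories  # Trajectory, EpisodesSampler, etc.

ns = 4  # assumed CartPole state dimension, for illustration only

trajectory = Trajectory(
    container = CircularArraySARTSTraces(capacity = 10_000, state = Float32 => (ns,)),
    sampler = EpisodesSampler(),  # hands back the stored episodes rather than random transitions
    controller = InsertSampleRatioController(ratio = 1 / 10_000),  # roughly one sampling pass per 10_000 inserts
)
```

As the in-diff note says, an episode-count-based controller would fit this setup better; the ratio-based controller is used as an approximation that fires after about one full buffer of insertions.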

src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,8 @@ function RLCore.Experiment(
            γ=0.99f0,
            rng=rng,
        ),
-        trajectory=Trajectory(container=Episode(ElasticArraySARTTraces(state=Float32 => (ns,))))
+        trajectory=Trajectory(container=CircularArraySARTSTraces(capacity = 10000, state=Float32 => (ns,)), sampler = EpisodesSampler(), controller = InsertSampleRatioController(ratio = 1/10000))
+        #Note: an EpisodeSamplerRatioController would be more adapted here.
     )
     stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI"))

src/ReinforcementLearningExperiments/test/runtests.jl

Lines changed: 3 additions & 2 deletions

@@ -3,15 +3,16 @@ using CUDA
 
 CUDA.allowscalar(false)
 
-run(E`JuliaRL_NFQ_CartPole`)
 run(E`JuliaRL_BasicDQN_CartPole`)
 run(E`JuliaRL_DQN_CartPole`)
+run(E`JuliaRL_NFQ_CartPole`)
 # run(E`JuliaRL_PrioritizedDQN_CartPole`)
 run(E`JuliaRL_QRDQN_CartPole`)
 run(E`JuliaRL_REMDQN_CartPole`)
 run(E`JuliaRL_IQN_CartPole`)
 run(E`JuliaRL_Rainbow_CartPole`)
-# run(E`JuliaRL_VPG_CartPole`)
+#run(E`JuliaRL_VPG_CartPole`)
+#run(E`JuliaRL_TRPO_CartPole`)
 run(E`JuliaRL_MPODiscrete_CartPole`)
 run(E`JuliaRL_MPOContinuous_CartPole`)
 run(E`JuliaRL_MPOCovariance_CartPole`)

src/ReinforcementLearningZoo/Project.toml

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@ ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 CUDA = "4"
@@ -28,7 +29,7 @@ LogExpFunctions = "0.3"
 NNlib = "0.8, 0.9"
 Optim = "1"
 ReinforcementLearningBase = "0.12"
-ReinforcementLearningCore = "0.12, 0.13"
+ReinforcementLearningCore = "^0.12.3, 0.13"
 StatsBase = "0.33, 0.34"
 Zygote = "0.6"
 julia = "1.9"
Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 # include("run.jl")
 include("util.jl")
-# include("vpg.jl")
+#include("vpg.jl")
 # include("A2C.jl")
 # include("ppo.jl")
 # include("A2CGAE.jl")
@@ -10,5 +10,5 @@ include("util.jl")
 # include("sac.jl")
 # include("maddpg.jl")
 # include("vmpo.jl")
-# include("trpo.jl")
+#include("trpo.jl")
 include("mpo.jl")

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl

Lines changed: 16 additions & 6 deletions

@@ -39,13 +39,23 @@ function Base.push!(p::Agent{<:TRPO}, ::PostEpisodeStage, env::AbstractEnv)
     empty!(p.trajectory.container)
 end
 
-RLBase.optimise!(::Agent{<:TRPO}, ::PostActStage) = nothing
-
-function RLBase.optimise!(π::TRPO, ::PostActStage, episode::Episode)
-    gain = discount_rewards(episode[:reward][:], π.γ)
-    for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
-        RLBase.optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
+function RLBase.optimise!(p::TRPO, ::PostEpisodeStage, trajectory::Trajectory)
+    has_optimized = false
+    for batch in trajectory #batch is a vector of Episode
+        gains = vcat(discount_rewards(ep[:reward], p.γ) for ep in batch)
+        states = reduce(ep[:state] for ep in batch) do s, s2
+            cat(s,s2, dims = ndims(first(batch[:state])))
+        end
+        actions = reduce(ep[:action] for ep in batch) do s, s2
+            cat(s, s2, dims = ndims(first(batch[:action])))
+        end
+        for inds in Iterators.partition(shuffle(p.rng, eachindex(gains)), p.batch_size)
+            RLBase.optimise!(p, (state=selectdim(states,ndims(states),inds), action=selectdim(actions,ndims(actions),inds), gain=gains[inds]))
+        end
+        has_optimized = true
     end
+    has_optimized && empty!(trajectory.container)
+    return nothing
 end
 
 function RLBase.optimise!(p::TRPO, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
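The rewritten `optimise!` concatenates every episode's `:state` and `:action` traces along their last dimension and then slices minibatches with `selectdim`, so the same code path works whether states are vectors or images. Below is a minimal standalone sketch of that pattern using only Base Julia; the episode shapes are made up for illustration.

```julia
using Random

# Two fake episodes with 4-dimensional observations: state arrays are (ns, T).
ep1_states, ep2_states = rand(Float32, 4, 5), rand(Float32, 4, 3)
ep1_actions, ep2_actions = rand(1:2, 5), rand(1:2, 3)

# Concatenate along the last axis, whatever the rank of the state array is.
states  = cat(ep1_states, ep2_states; dims = ndims(ep1_states))    # (4, 8)
actions = cat(ep1_actions, ep2_actions; dims = ndims(ep1_actions)) # (8,)

# Shuffle the step indices and cut them into minibatches along that last axis.
rng, batch_size = Xoshiro(1), 4
for inds in Iterators.partition(shuffle(rng, eachindex(actions)), batch_size)
    s = selectdim(states, ndims(states), inds)  # view over the selected steps
    a = actions[inds]
    @show size(s) length(a)
end
```

Indexing the last dimension via `ndims` is what keeps the code dimension-agnostic: a (ns, T) state trace and an (H, W, C, T) image trace are handled identically.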

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl

Lines changed: 17 additions & 13 deletions

@@ -31,26 +31,30 @@ function RLBase.plan!(π::VPG, env::AbstractEnv)
 end
 
 function RLBase.optimise!(p::VPG, ::PostEpisodeStage, trajectory::Trajectory)
-    trajectory.container[] = true
-    for batch in trajectory
-        RLBase.optimise!(p, batch)
-    end
-    empty!(trajectory.container)
-end
-
-function RLBase.optimise!(π::VPG, episode::Episode)
-    gain = discount_rewards(episode[:reward][:], π.γ)
-    for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
-        RLBase.optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
+    has_optimized = false
+    for batch in trajectory #batch is a vector of Episode
+        gains = vcat(discount_rewards(ep[:reward], p.γ) for ep in batch)
+        states = reduce(ep[:state] for ep in batch) do s, s2
+            cat(s,s2, dims = ndims(first(batch[:state])))
+        end
+        actions = reduce(ep[:action] for ep in batch) do s, s2
+            cat(s, s2, dims = ndims(first(batch[:action])))
+        end
+        for inds in Iterators.partition(shuffle(p.rng, eachindex(gains)), p.batch_size)
+            RLBase.optimise!(p, (state=selectdim(states,ndims(states),inds), action=selectdim(actions,ndims(actions),inds), gain=gains[inds]))
+        end
+        has_optimized = true
     end
+    has_optimized && empty!(trajectory.container)
+    return nothing
 end
 
 function RLBase.optimise!(p::VPG, batch::NamedTuple{(:state, :action, :gain)})
     A = p.approximator
     B = p.baseline
-    s, a, g = map(Array, batch) # !!! FIXME
+    s, a, g = batch[:state], batch[:action], batch[:gain]
     local δ
-
+    println(s)
     if isnothing(B)
         δ = normalise(g)
         loss = 0
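Both refactored methods rely on `discount_rewards` to turn an episode's reward vector into per-step discounted returns before the steps are shuffled into minibatches. As a point of reference, a self-contained version of that computation might look like the sketch below; this is an illustrative stand-in, not the helper defined in the Zoo's util.jl.

```julia
# Hypothetical stand-in for the Zoo's `discount_rewards`: reward-to-go returns
# G_t = r_t + γ * G_{t+1}, computed in a single backward pass over the episode.
function discounted_returns(rewards::AbstractVector{T}, γ::T) where {T<:AbstractFloat}
    gains = similar(rewards)
    running = zero(T)
    for t in length(rewards):-1:1
        running = rewards[t] + γ * running
        gains[t] = running
    end
    return gains
end

discounted_returns(Float32[1, 1, 1], 0.99f0)  # ≈ Float32[2.9701, 1.99, 1.0]
```

The per-episode gain vectors are then concatenated with the states and actions of the whole sampled batch, which is why the inner loop can shuffle over `eachindex(gains)` directly.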
