Commit 83310a9

add REMDQN (#708)
1 parent fc74394 commit 83310a9

11 files changed (+109 lines, -174 lines)

src/ReinforcementLearningCore/src/policies/agent.jl

Lines changed: 7 additions & 8 deletions
@@ -2,7 +2,7 @@ export Agent

using Base.Threads: @spawn

-import Functors
+using Functors: @functor

"""
    Agent(;policy, trajectory)
@@ -20,7 +20,7 @@ mutable struct Agent{P,T} <: AbstractPolicy
    trajectory::T
    cache::NamedTuple # trajectory do not support partial inserting

-    function Agent(policy::P, trajectory::T, cache = NamedTuple()) where {P,T}
+    function Agent(policy::P, trajectory::T, cache=NamedTuple()) where {P,T}
        agent = new{P,T}(policy, trajectory, cache)
        if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
            bind(trajectory, @spawn(optimise!(p, t)))
@@ -29,7 +29,7 @@ mutable struct Agent{P,T} <: AbstractPolicy
    end
end

-Agent(; policy, trajectory, cache = NamedTuple()) = Agent(policy, trajectory, cache)
+Agent(; policy, trajectory, cache=NamedTuple()) = Agent(policy, trajectory, cache)

RLBase.optimise!(agent::Agent) = optimise!(TrajectoryStyle(agent.trajectory), agent)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent) =
@@ -44,21 +44,20 @@ function RLBase.optimise!(policy::AbstractPolicy, trajectory::Trajectory)
    end
end

-Functors.functor(x::Agent) =
-    (policy = x.policy,), y -> Agent(y.policy, x.trajectory, x.cache)
+@functor Agent (policy,)

# !!! TODO: In async scenarios, parameters of the policy may still be updating
# (partially), which will result to incorrect action. This should be addressed
# in Oolong.jl with a wrapper
function (agent::Agent)(env::AbstractEnv)
    action = agent.policy(env)
-    push!(agent.trajectory, (agent.cache..., action = action))
+    push!(agent.trajectory, (agent.cache..., action=action))
    agent.cache = (;)
    action
end

(agent::Agent)(::PreActStage, env::AbstractEnv) =
-    agent.cache = (agent.cache..., state = state(env))
+    agent.cache = (agent.cache..., state=state(env))

(agent::Agent)(::PostActStage, env::AbstractEnv) =
-    agent.cache = (agent.cache..., reward = reward(env), terminal = is_terminated(env))
+    agent.cache = (agent.cache..., reward=reward(env), terminal=is_terminated(env))
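The recurring change in this commit is replacing hand-written `Functors.functor` definitions with the `@functor` macro, listing only the fields that should be treated as functor children. A minimal sketch of what `@functor Agent (policy,)` implies; the `Wrapper` type below is hypothetical and not part of the commit:

# Minimal sketch of the `@functor` pattern used throughout this commit.
# `Wrapper` is a hypothetical stand-in for Agent's (policy, trajectory) layout.
using Functors: @functor, fmap

struct Wrapper{P,T}
    policy::P
    trajectory::T
end

# Register only `policy` as a functor child: structural transforms such as
# `fmap` (and hence `Flux.gpu`) recurse into the policy and leave the rest alone.
@functor Wrapper (policy,)

w  = Wrapper([1.0f0, 2.0f0], "replay buffer stays untouched")
w2 = fmap(x -> 2 .* x, w)   # w2.policy == [2.0f0, 4.0f0]; w2.trajectory is unchanged

This matches the removed closure form, which rebuilt the Agent from a transformed policy while carrying the trajectory and cache over unchanged.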

src/ReinforcementLearningCore/src/policies/learners.jl

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
export AbstractLearner, Approximator

import Flux
-import Functors
+using Functors: @functor

abstract type AbstractLearner end

@@ -12,7 +12,7 @@ Base.@kwdef mutable struct Approximator{M,O}
    optimiser::O
end

-Functors.functor(x::Approximator) = (model=x.model,), y -> Approximator(y.model, x.state)
+@functor Approximator (model,)

(A::Approximator)(x) = A.model(x)

src/ReinforcementLearningCore/src/policies/q_based_policy.jl

Lines changed: 2 additions & 3 deletions
@@ -3,15 +3,14 @@ export QBasedPolicy
include("learners.jl")
include("explorers/explorers.jl")

-import Functors
+using Functors: @functor

Base.@kwdef mutable struct QBasedPolicy{L,E} <: AbstractPolicy
    learner::L
    explorer::E
end

-Functors.functor(x::QBasedPolicy) =
-    (learner = x.learner,), y -> QBasedPolicy(y.learner, x.explorer)
+@functor QBasedPolicy (learner,)

(p::QBasedPolicy)(env) = p.explorer(p.learner(env), legal_action_space_mask(env))

src/ReinforcementLearningCore/src/utils/networks.jl

Lines changed: 5 additions & 7 deletions
@@ -1,8 +1,6 @@
-import Functors
+using Functors: @functor
import Flux

-using Setfield: @set
-
#####
# ActorCritic
#####
@@ -18,7 +16,7 @@ Base.@kwdef struct ActorCritic{A,C,O}
    critic::C
end

-Functors.@functor ActorCritic
+@functor ActorCritic

#####
# GaussianNetwork
@@ -44,7 +42,7 @@ end

GaussianNetwork(pre, μ, logσ, normalizer=tanh) = GaussianNetwork(pre, μ, logσ, 0.0f0, Inf32, normalizer)

-Functors.@functor GaussianNetwork
+@functor GaussianNetwork

"""
This function is compatible with a multidimensional action space. When outputting an action, it uses the `normalizer` function to normalize it elementwise.
@@ -138,7 +136,7 @@ end

CovGaussianNetwork(pre, m, s) = CovGaussianNetwork(pre, m, s, tanh)

-Functors.@functor CovGaussianNetwork
+@functor CovGaussianNetwork

"""
    (model::CovGaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
@@ -397,7 +395,7 @@ end

TwinNetwork(x; kw...) = TwinNetwork(; source=x, target=deepcopy(x), kw...)

-Functors.functor(x::TwinNetwork) = (; source=x.source), y -> @set x.source = y.source
+@functor TwinNetwork (source,)

(model::TwinNetwork)(x) = model.source(x)
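
`TwinNetwork` pairs an online `source` network with a `target` copy, and `@functor TwinNetwork (source,)` keeps only the source visible to parameter collection and device transfers. A rough usage sketch, assuming ReinforcementLearningCore is loaded and that `TwinNetwork` accepts the `sync_freq` keyword shown in the REMDQN experiment below:

using Flux

model = Chain(Dense(4, 32, relu), Dense(32, 2))   # hypothetical Q-network
twin  = TwinNetwork(model; sync_freq=100)         # source = model, target = deepcopy(model)

twin(rand(Float32, 4, 8))                         # forwards through the source only
# Because of `@functor TwinNetwork (source,)`, Flux.params(twin) and gpu(twin)
# only see the source; the target copy stays out of gradient collection and is
# expected to be refreshed from the source every `sync_freq` steps by the learner.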

src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl

Lines changed: 51 additions & 42 deletions
@@ -2,7 +2,7 @@
# title: JuliaRL\_REMDQN\_CartPole
# cover: assets/JuliaRL_REMDQN_CartPole.png
# description: REMDQN applied to CartPole
-# date: 2021-05-22
+# date: 2021-06-25
# author: "[Jun Tian](https://github.com/findmyway)"
# ---

@@ -16,61 +16,70 @@ function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:REMDQN},
    ::Val{:CartPole},
-    ::Nothing;
-    seed = 123,
+    ; seed=123,
+    ensemble_num=16
)
    rng = StableRNG(seed)

-    env = CartPoleEnv(; T = Float32, rng = rng)
+    env = CartPoleEnv(; T=Float32, rng=rng)
    ns, na = length(state(env)), length(action_space(env))
-    ensemble_num = 16
+
+    n = 1
+    γ = 0.99f0

    agent = Agent(
-        policy = QBasedPolicy(
-            learner = REMDQNLearner(
-                approximator = NeuralNetworkApproximator(
-                    model = Chain(
-                        ## Multi-head method, please refer to "https://github.com/google-research/batch_rl/tree/b55ba35ebd2381199125dd77bfac9e9c59a64d74/batch_rl/multi_head".
-                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, na * ensemble_num; init = glorot_uniform(rng)),
-                    ) |> gpu,
-                    optimizer = ADAM(),
-                ),
-                target_approximator = NeuralNetworkApproximator(
-                    model = Chain(
-                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, na * ensemble_num; init = glorot_uniform(rng)),
-                    ) |> gpu,
+        policy=QBasedPolicy(
+            learner=REMDQNLearner(
+                approximator=Approximator(
+                    model=TwinNetwork(
+                        Chain(
+                            ## Multi-head method, please refer to "https://github.com/google-research/batch_rl/tree/b55ba35ebd2381199125dd77bfac9e9c59a64d74/batch_rl/multi_head".
+                            Dense(ns, 128, relu; init=glorot_uniform(rng)),
+                            Dense(128, 128, relu; init=glorot_uniform(rng)),
+                            Dense(128, na * ensemble_num; init=glorot_uniform(rng)),
+                        ),
+                        sync_freq=100
+                    ),
+                    optimiser=ADAM(),
                ),
-                loss_func = huber_loss,
-                stack_size = nothing,
-                batch_size = 32,
-                update_horizon = 1,
-                min_replay_history = 100,
-                update_freq = 1,
-                target_update_freq = 100,
-                ensemble_num = ensemble_num,
-                ensemble_method = :rand,
-                rng = rng,
+                n=n,
+                γ=γ,
+                loss_func=huber_loss,
+                ensemble_num=ensemble_num,
+                ensemble_method=:rand,
+                rng=rng,
            ),
-            explorer = EpsilonGreedyExplorer(
-                kind = :exp,
-                ϵ_stable = 0.01,
-                decay_steps = 500,
-                rng = rng,
+            explorer=EpsilonGreedyExplorer(
+                kind=:exp,
+                ϵ_stable=0.01,
+                decay_steps=500,
+                rng=rng,
            ),
        ),
-        trajectory = CircularArraySARTTrajectory(
-            capacity = 1000,
-            state = Vector{Float32} => (ns,),
-        ),
+        trajectory=Trajectory(
+            container=CircularArraySARTTraces(
+                capacity=1000,
+                state=Float32 => (ns,),
+            ),
+            sampler=NStepBatchSampler{SS′ART}(
+                n=n,
+                γ=γ,
+                batch_size=32,
+                rng=rng
+            ),
+            controller=InsertSampleRatioController(
+                threshold=100,
+                n_inserted=-1
+            )
+        )
    )

    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
-    Experiment(agent, env, stop_condition, hook, "")
+
+    ## !!! note that REMDQN is used in offline RL
+    ## TODO: use DQN to collect experiences and then optimise the REMDQN
+    Experiment(agent, env, stop_condition, hook)
end

#+ tangle=false
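
The `na * ensemble_num` output layer and `ensemble_method=:rand` follow the random ensemble mixture (REM) scheme from the batch_rl repository linked in the comment above: the multi-head output is split into `ensemble_num` Q-heads and combined with random convex weights at each update. A rough, self-contained sketch of that mixing step; the helper below is hypothetical and is not the package's `REMDQNLearner` implementation:

using Flux, Random

# Illustrative REM mixing, assuming a multi-head model whose output has size
# (na * ensemble_num, batch_size). Weights are drawn afresh for each update.
function rem_q_values(model, state, na, ensemble_num; rng=Random.GLOBAL_RNG)
    q = model(state)                             # (na * ensemble_num, batch)
    q = reshape(q, na, ensemble_num, :)          # (na, ensemble_num, batch)
    α = rand(rng, Float32, 1, ensemble_num, 1)   # random mixture weights
    α = α ./ sum(α)                              # convex combination (sums to 1)
    dropdims(sum(q .* α; dims=2); dims=2)        # (na, batch) mixed Q-values
end

na, ensemble_num = 2, 16
model = Chain(Dense(4, 128, relu), Dense(128, na * ensemble_num))
q = rem_q_values(model, rand(Float32, 4, 32), na, ensemble_num)   # 2×32 matrix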

src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_BasicDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_DQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_PrioritizedDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_QRDQN_CartPole.jl"))
+include(joinpath(EXPERIMENTS_DIR, "JuliaRL_REMDQN_CartPole.jl"))

# dynamic loading environments
function __init__() end

src/ReinforcementLearningExperiments/test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -7,9 +7,9 @@ run(E`JuliaRL_BasicDQN_CartPole`)
run(E`JuliaRL_DQN_CartPole`)
run(E`JuliaRL_PrioritizedDQN_CartPole`)
run(E`JuliaRL_QRDQN_CartPole`)
+run(E`JuliaRL_REMDQN_CartPole`)
# run(E`JuliaRL_BC_CartPole`)
# run(E`JuliaRL_Rainbow_CartPole`)
-# run(E`JuliaRL_REMDQN_CartPole`)
# run(E`JuliaRL_IQN_CartPole`)
# run(E`JuliaRL_VMPO_CartPole`)
# run(E`JuliaRL_VPG_CartPole`)
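
With the experiment registered, it runs like the other DQN variants, and the `ensemble_num` keyword from the experiment signature above can also be passed directly. A usage sketch; the fully qualified `ReinforcementLearning.Experiment(...)` call form is inferred from the method signature in this commit and should be treated as an assumption:

using ReinforcementLearning
using ReinforcementLearningExperiments

run(E`JuliaRL_REMDQN_CartPole`)   # defaults from the signature: seed=123, ensemble_num=16

# Assumed equivalent, calling the method added in this commit directly:
ex = ReinforcementLearning.Experiment(Val(:JuliaRL), Val(:REMDQN), Val(:CartPole); seed=1, ensemble_num=8)
run(ex)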

src/ReinforcementLearningZoo/src/algorithms/dqns/basic_dqn.jl

Lines changed: 2 additions & 3 deletions
@@ -3,8 +3,7 @@ export BasicDQNLearner
using Flux: gradient, params
using Zygote: ignore
using Setfield: @set
-
-import Functors
+using Functors: @functor

"""
    BasicDQNLearner(;kwargs...)
@@ -32,7 +31,7 @@ Base.@kwdef mutable struct BasicDQNLearner{Q} <: AbstractLearner
    loss::Float32 = 0.0f0
end

-Functors.functor(x::BasicDQNLearner) = (Q=x.approximator,), y -> @set x.approximator = y.Q
+@functor BasicDQNLearner (approximator,)

(L::BasicDQNLearner)(s::AbstractArray) = L.approximator(s)


Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
export DQNLearner
22

3-
using Setfield: @set
43
using Random: AbstractRNG, GLOBAL_RNG
5-
import Functors
4+
using Functors: @functor
65

76
Base.@kwdef mutable struct DQNLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner
87
approximator::A
@@ -17,7 +16,7 @@ end
1716

1817
(L::DQNLearner)(s::AbstractArray) = L.approximator(s)
1918

20-
Functors.functor(x::DQNLearner) = (; approximator=x.approximator), y -> @set x.approximator = y.approximator
19+
@functor DQNLearner (approximator,)
2120

2221
function RLBase.optimise!(learner::DQNLearner, batch::Union{NamedTuple{SS′ART},NamedTuple{SS′L′ART}})
2322
A = learner.approximator

src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ include("basic_dqn.jl")
22
include("dqn.jl")
33
include("prioritized_dqn.jl")
44
include("qr_dqn.jl")
5-
# include("rem_dqn.jl")
5+
include("rem_dqn.jl")
66
# include("rainbow.jl")
77
# include("iqn.jl")
88
# include("common.jl")
