
Commit 6e00c33

Merge branch 'EpisodeResetCondition' of https://github.com/HenriDeh/ReinforcementLearning.jl into EpisodeResetCondition
2 parents: ae26ac2 + 6e8575c

File tree: 13 files changed, +214 −378 lines

src/ReinforcementLearningCore/src/policies/agent.jl

Lines changed: 7 additions & 8 deletions
@@ -2,7 +2,7 @@ export Agent
 
 using Base.Threads: @spawn
 
-import Functors
+using Functors: @functor
 
 """
     Agent(;policy, trajectory)
@@ -20,7 +20,7 @@ mutable struct Agent{P,T} <: AbstractPolicy
     trajectory::T
     cache::NamedTuple # trajectory do not support partial inserting
 
-    function Agent(policy::P, trajectory::T, cache = NamedTuple()) where {P,T}
+    function Agent(policy::P, trajectory::T, cache=NamedTuple()) where {P,T}
         agent = new{P,T}(policy, trajectory, cache)
         if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
             bind(trajectory, @spawn(optimise!(p, t)))
@@ -29,7 +29,7 @@ mutable struct Agent{P,T} <: AbstractPolicy
     end
 end
 
-Agent(; policy, trajectory, cache = NamedTuple()) = Agent(policy, trajectory, cache)
+Agent(; policy, trajectory, cache=NamedTuple()) = Agent(policy, trajectory, cache)
 
 RLBase.optimise!(agent::Agent) = optimise!(TrajectoryStyle(agent.trajectory), agent)
 RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent) =
@@ -44,21 +44,20 @@ function RLBase.optimise!(policy::AbstractPolicy, trajectory::Trajectory)
     end
 end
 
-Functors.functor(x::Agent) =
-    (policy = x.policy,), y -> Agent(y.policy, x.trajectory, x.cache)
+@functor Agent (policy,)
 
 # !!! TODO: In async scenarios, parameters of the policy may still be updating
 # (partially), which will result to incorrect action. This should be addressed
 # in Oolong.jl with a wrapper
 function (agent::Agent)(env::AbstractEnv)
     action = agent.policy(env)
-    push!(agent.trajectory, (agent.cache..., action = action))
+    push!(agent.trajectory, (agent.cache..., action=action))
     agent.cache = (;)
     action
 end
 
 (agent::Agent)(::PreActStage, env::AbstractEnv) =
-    agent.cache = (agent.cache..., state = state(env))
+    agent.cache = (agent.cache..., state=state(env))
 
 (agent::Agent)(::PostActStage, env::AbstractEnv) =
-    agent.cache = (agent.cache..., reward = reward(env), terminal = is_terminated(env))
+    agent.cache = (agent.cache..., reward=reward(env), terminal=is_terminated(env))
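
The hand-written `Functors.functor` method is replaced by `@functor Agent (policy,)`. A minimal standalone sketch of what that form gives you (toy types, not the RL.jl structs), assuming only Functors.jl:

using Functors: @functor, functor

struct Policy; w::Vector{Float32}; end
struct MyAgent; policy::Policy; trajectory::Vector{Any}; cache::NamedTuple; end

@functor MyAgent (policy,)   # only `policy` is a traversable child

children, rebuild = functor(MyAgent(Policy([1.0f0]), Any[], (;)))
children                               # (policy = Policy(Float32[1.0]),)
rebuild((policy = Policy([2.0f0]),))   # new policy, original trajectory/cache reused

So fmap-based traversal (and hence `gpu`/`cpu` movement and parameter collection) only ever touches the policy, leaving the trajectory and cache alone, which is what the replaced hand-written method did.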
Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 export AbstractLearner, Approximator
 
 import Flux
-import Functors
+using Functors: @functor
 
 abstract type AbstractLearner end
 
@@ -12,9 +12,9 @@ Base.@kwdef mutable struct Approximator{M,O}
     optimiser::O
 end
 
-Functors.functor(x::Approximator) = (model=x.model,), y -> Approximator(y.model, x.state)
+@functor Approximator (model,)
 
-(A::Approximator)(x) = A.model(x)
+(A::Approximator)(args...) = A.model(args...)
 
 RLBase.optimise!(A::Approximator, gs) =
     Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
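
Two small behaviour changes here: parameter traversal is restricted to `model` via `@functor`, and the call operator now forwards every argument instead of a single input, so wrapped models that take more than one argument work unchanged. A rough sketch (the two-argument model is hypothetical, only for illustration):

using Flux

two_arg_model(x, τ) = x .* τ     # hypothetical model taking two inputs

approx = Approximator(model=two_arg_model, optimiser=ADAM())
approx(rand(Float32, 4), 0.5f0)  # all arguments are forwarded to `approx.model`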

src/ReinforcementLearningCore/src/policies/q_based_policy.jl

Lines changed: 2 additions & 3 deletions
@@ -3,15 +3,14 @@ export QBasedPolicy
 include("learners.jl")
 include("explorers/explorers.jl")
 
-import Functors
+using Functors: @functor
 
 Base.@kwdef mutable struct QBasedPolicy{L,E} <: AbstractPolicy
     learner::L
     explorer::E
 end
 
-Functors.functor(x::QBasedPolicy) =
-    (learner = x.learner,), y -> QBasedPolicy(y.learner, x.explorer)
+@functor QBasedPolicy (learner,)
 
 (p::QBasedPolicy)(env) = p.explorer(p.learner(env), legal_action_space_mask(env))
 
src/ReinforcementLearningCore/src/utils/networks.jl

Lines changed: 6 additions & 8 deletions
@@ -1,8 +1,6 @@
-import Functors
+using Functors: @functor
 import Flux
 
-using Setfield: @set
-
 #####
 # ActorCritic
 #####
@@ -18,7 +16,7 @@ Base.@kwdef struct ActorCritic{A,C,O}
     critic::C
 end
 
-Functors.@functor ActorCritic
+@functor ActorCritic
 
 #####
 # GaussianNetwork
@@ -44,7 +42,7 @@ end
 
 GaussianNetwork(pre, μ, logσ, normalizer=tanh) = GaussianNetwork(pre, μ, logσ, 0.0f0, Inf32, normalizer)
 
-Functors.@functor GaussianNetwork
+@functor GaussianNetwork
 
 """
 This function is compatible with a multidimensional action space. When outputting an action, it uses the `normalizer` function to normalize it elementwise.
@@ -138,7 +136,7 @@ end
 
 CovGaussianNetwork(pre, m, s) = CovGaussianNetwork(pre, m, s, tanh)
 
-Functors.@functor CovGaussianNetwork
+@functor CovGaussianNetwork
 
 """
     (model::CovGaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
@@ -397,9 +395,9 @@ end
 
 TwinNetwork(x; kw...) = TwinNetwork(; source=x, target=deepcopy(x), kw...)
 
-Functors.functor(x::TwinNetwork) = (; source=x.source), y -> @set x.source = y.source
+@functor TwinNetwork (source,)
 
-(model::TwinNetwork)(x) = model.source(x)
+(model::TwinNetwork)(args...) = model.source(args...)
 
 function RLBase.optimise!(A::Approximator{<:TwinNetwork}, gs)
     Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
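
For `TwinNetwork`, the `@functor ... (source,)` form drops the Setfield dependency while keeping the old behaviour: only the online `source` network is visible to Functors-based traversal. A rough sketch of the consequence (constructor from the diff, `sync_freq` keyword as used by the experiments below):

using Flux

net = TwinNetwork(Chain(Dense(4, 32, relu), Dense(32, 2)); sync_freq=100)

Flux.params(net)        # parameters of `net.source` only; `net.target` is invisible
net(rand(Float32, 4))   # `(model::TwinNetwork)(args...)` forwards to `net.source`

The target copy is therefore untouched by gradient updates and is presumably synchronised from the source during `optimise!`, with `sync_freq` governing how often.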

src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl

Lines changed: 47 additions & 40 deletions
@@ -3,7 +3,7 @@
 # title: JuliaRL\_IQN\_CartPole
 # cover: assets/JuliaRL_IQN_CartPole.png
 # description: IQN applied to CartPole
-# date: 2021-05-22
+# date: 2022-06-27
 # author: "[Jun Tian](https://github.com/findmyway)"
 # ---
 
@@ -12,18 +12,16 @@ using ReinforcementLearning
 using StableRNGs
 using Flux
 using Flux.Losses
-using CUDA
 
 function RL.Experiment(
     ::Val{:JuliaRL},
     ::Val{:IQN},
     ::Val{:CartPole},
-    ::Nothing;
-    seed = 123,
+    ; seed=123
 )
     rng = StableRNG(seed)
-    device_rng = CUDA.functional() ? CUDA.CURAND.RNG() : rng
-    env = CartPoleEnv(; T = Float32, rng = rng)
+    device_rng = rng
+    env = CartPoleEnv(; T=Float32, rng=rng)
     ns, na = length(state(env)), length(action_space(env))
     init = glorot_uniform(rng)
     Nₑₘ = 16
@@ -32,51 +30,60 @@ function RL.Experiment(
 
     nn_creator() =
         ImplicitQuantileNet(
-            ψ = Dense(ns, n_hidden, relu; init = init),
-            ϕ = Dense(Nₑₘ, n_hidden, relu; init = init),
-            header = Dense(n_hidden, na; init = init),
+            ψ=Dense(ns, n_hidden, relu; init=init),
+            ϕ=Dense(Nₑₘ, n_hidden, relu; init=init),
+            header=Dense(n_hidden, na; init=init),
        ) |> gpu
 
     agent = Agent(
-        policy = QBasedPolicy(
-            learner = IQNLearner(
-                approximator = NeuralNetworkApproximator(
-                    model = nn_creator(),
-                    optimizer = ADAM(0.001),
+        policy=QBasedPolicy(
+            learner=IQNLearner(
+                approximator=Approximator(
+                    model=TwinNetwork(
+                        ImplicitQuantileNet(
+                            ψ=Dense(ns, n_hidden, relu; init=init),
+                            ϕ=Dense(Nₑₘ, n_hidden, relu; init=init),
+                            header=Dense(n_hidden, na; init=init),
+                        ),
+                        sync_freq=100
+                    ),
+                    optimiser=ADAM(0.001),
                 ),
-                target_approximator = NeuralNetworkApproximator(model = nn_creator()),
-                κ = κ,
-                N = 8,
-                N′ = 8,
-                Nₑₘ = Nₑₘ,
-                K = 32,
-                γ = 0.99f0,
-                stack_size = nothing,
-                batch_size = 32,
-                update_horizon = 1,
-                min_replay_history = 100,
-                update_freq = 1,
-                target_update_freq = 100,
-                default_priority = 1.0f2,
-                rng = rng,
-                device_rng = device_rng,
+                κ=κ,
+                N=8,
+                N′=8,
+                Nₑₘ=Nₑₘ,
+                K=32,
+                γ=0.99f0,
+                rng=rng,
+                device_rng=device_rng,
            ),
-            explorer = EpsilonGreedyExplorer(
-                kind = :exp,
-                ϵ_stable = 0.01,
-                decay_steps = 500,
-                rng = rng,
+            explorer=EpsilonGreedyExplorer(
+                kind=:exp,
+                ϵ_stable=0.01,
+                decay_steps=500,
+                rng=rng,
            ),
        ),
-        trajectory = CircularArrayPSARTTrajectory(
-            capacity = 1000,
-            state = Vector{Float32} => (ns,),
-        ),
+        trajectory=Trajectory(
+            container=CircularArraySARTTraces(
+                capacity=1000,
+                state=Float32 => (ns,),
+            ),
+            sampler=BatchSampler{SS′ART}(
+                batch_size=32,
+                rng=rng
+            ),
+            controller=InsertSampleRatioController(
+                threshold=100,
+                n_inserted=-1
+            )
+        )
    )
 
    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
-    Experiment(agent, env, stop_condition, hook, "")
+    Experiment(agent, env, stop_condition, hook)
 end
 
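Besides the move to `Approximator`/`TwinNetwork`, this experiment swaps the monolithic `CircularArrayPSARTTrajectory` for an explicit `Trajectory` built from a container, a sampler and a controller. A hedged reading of how the pieces relate to the old keywords (constructor signatures taken from the diff; the mapping onto `batch_size`/`min_replay_history` is an interpretation), assuming the trajectory types are re-exported by `using ReinforcementLearning` as in the experiment above:

using ReinforcementLearning, StableRNGs

rng = StableRNG(123)
ns = 4   # CartPole state dimension

traj = Trajectory(
    ## replay storage: one circular trace per field of the SART tuple
    container=CircularArraySARTTraces(capacity=1000, state=Float32 => (ns,)),
    ## draws SS′ART batches, taking over the old learner-level `batch_size=32`
    sampler=BatchSampler{SS′ART}(batch_size=32, rng=rng),
    ## sampling starts after 100 inserts, mirroring the old `min_replay_history=100`
    controller=InsertSampleRatioController(threshold=100, n_inserted=-1),
)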

src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl

Lines changed: 51 additions & 42 deletions
@@ -2,7 +2,7 @@
 # title: JuliaRL\_REMDQN\_CartPole
 # cover: assets/JuliaRL_REMDQN_CartPole.png
 # description: REMDQN applied to CartPole
-# date: 2021-05-22
+# date: 2021-06-25
 # author: "[Jun Tian](https://github.com/findmyway)"
 # ---
 
@@ -16,61 +16,70 @@ function RL.Experiment(
     ::Val{:JuliaRL},
     ::Val{:REMDQN},
     ::Val{:CartPole},
-    ::Nothing;
-    seed = 123,
+    ; seed=123,
+    ensemble_num=16
 )
     rng = StableRNG(seed)
 
-    env = CartPoleEnv(; T = Float32, rng = rng)
+    env = CartPoleEnv(; T=Float32, rng=rng)
     ns, na = length(state(env)), length(action_space(env))
-    ensemble_num = 16
+
+    n = 1
+    γ = 0.99f0
 
     agent = Agent(
-        policy = QBasedPolicy(
-            learner = REMDQNLearner(
-                approximator = NeuralNetworkApproximator(
-                    model = Chain(
-                        ## Multi-head method, please refer to "https://github.com/google-research/batch_rl/tree/b55ba35ebd2381199125dd77bfac9e9c59a64d74/batch_rl/multi_head".
-                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, na * ensemble_num; init = glorot_uniform(rng)),
-                    ) |> gpu,
-                    optimizer = ADAM(),
-                ),
-                target_approximator = NeuralNetworkApproximator(
-                    model = Chain(
-                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, 128, relu; init = glorot_uniform(rng)),
-                        Dense(128, na * ensemble_num; init = glorot_uniform(rng)),
-                    ) |> gpu,
+        policy=QBasedPolicy(
+            learner=REMDQNLearner(
+                approximator=Approximator(
+                    model=TwinNetwork(
+                        Chain(
+                            ## Multi-head method, please refer to "https://github.com/google-research/batch_rl/tree/b55ba35ebd2381199125dd77bfac9e9c59a64d74/batch_rl/multi_head".
+                            Dense(ns, 128, relu; init=glorot_uniform(rng)),
+                            Dense(128, 128, relu; init=glorot_uniform(rng)),
+                            Dense(128, na * ensemble_num; init=glorot_uniform(rng)),
+                        ),
+                        sync_freq=100
+                    ),
+                    optimiser=ADAM(),
                 ),
-                loss_func = huber_loss,
-                stack_size = nothing,
-                batch_size = 32,
-                update_horizon = 1,
-                min_replay_history = 100,
-                update_freq = 1,
-                target_update_freq = 100,
-                ensemble_num = ensemble_num,
-                ensemble_method = :rand,
-                rng = rng,
+                n=n,
+                γ=γ,
+                loss_func=huber_loss,
+                ensemble_num=ensemble_num,
+                ensemble_method=:rand,
+                rng=rng,
            ),
-            explorer = EpsilonGreedyExplorer(
-                kind = :exp,
-                ϵ_stable = 0.01,
-                decay_steps = 500,
-                rng = rng,
+            explorer=EpsilonGreedyExplorer(
+                kind=:exp,
+                ϵ_stable=0.01,
+                decay_steps=500,
+                rng=rng,
            ),
        ),
-        trajectory = CircularArraySARTTrajectory(
-            capacity = 1000,
-            state = Vector{Float32} => (ns,),
-        ),
+        trajectory=Trajectory(
+            container=CircularArraySARTTraces(
+                capacity=1000,
+                state=Float32 => (ns,),
+            ),
+            sampler=NStepBatchSampler{SS′ART}(
+                n=n,
+                γ=γ,
+                batch_size=32,
+                rng=rng
+            ),
+            controller=InsertSampleRatioController(
+                threshold=100,
+                n_inserted=-1
+            )
        )
    )
 
    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
-    Experiment(agent, env, stop_condition, hook, "")
+
+    ## !!! note that REMDQN is used in offline RL
+    ## TODO: use DQN to collect experiences and then optimise the REMDQN
+    Experiment(agent, env, stop_condition, hook)
 end
 
 #+ tangle=false
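
For background on the REM-specific keywords (a sketch of the idea, not a transcription of the learner internals): the last Dense layer outputs `na * ensemble_num` Q-values, and `ensemble_method=:rand` corresponds to REM's random convex combination of the ensemble heads, roughly:

na, ensemble_num = 2, 16
q_flat = rand(Float32, na * ensemble_num)       # flat output of the last Dense layer
q = reshape(q_flat, na, ensemble_num)           # one column of Q-values per head
α = rand(Float32, ensemble_num); α ./= sum(α)   # random convex weights (sum to 1)
q_mix = q * α                                   # mixed Q-value per action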

src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_BasicDQN_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_DQN_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_PrioritizedDQN_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_QRDQN_CartPole.jl"))
+include(joinpath(EXPERIMENTS_DIR, "JuliaRL_REMDQN_CartPole.jl"))
+include(joinpath(EXPERIMENTS_DIR, "JuliaRL_IQN_CartPole.jl"))
 
 # dynamic loading environments
 function __init__() end
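
With these two `include`s the REMDQN and IQN experiments defined above become part of the loaded experiment set. A hedged usage sketch (assuming the usual `run(::Experiment)` entry point and the exported `RL` alias):

using ReinforcementLearning
using ReinforcementLearningExperiments

## build the IQN experiment via the method added above and run it
ex = RL.Experiment(Val(:JuliaRL), Val(:IQN), Val(:CartPole); seed=123)
run(ex)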
