
Commit 4acf579

add IQN (#710)
1 parent 83310a9 commit 4acf579

7 files changed: +105 additions, -204 deletions


src/ReinforcementLearningCore/src/policies/learners.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ end
 
 @functor Approximator (model,)
 
-(A::Approximator)(x) = A.model(x)
+(A::Approximator)(args...) = A.model(args...)
 
 RLBase.optimise!(A::Approximator, gs) =
     Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
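
The widened signature matters for IQN: unlike the single-input value networks used by the other DQN variants, an ImplicitQuantileNet is called with both a state batch and the sampled quantile embeddings, so the wrapper has to forward every positional argument. A minimal sketch of the new behaviour, using the keyword constructor that appears in the experiment below and a hypothetical two-argument model in place of the real network:

    using Flux
    using ReinforcementLearningCore

    # Stand-in for a model taking two positional inputs, e.g. (state, embedding).
    two_arg_model(s, emb) = s .+ emb

    A = Approximator(model=two_arg_model, optimiser=ADAM(0.001))
    A(ones(Float32, 4), fill(2f0, 4))  # both arguments are now passed on to A.model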

src/ReinforcementLearningCore/src/utils/networks.jl

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ TwinNetwork(x; kw...) = TwinNetwork(; source=x, target=deepcopy(x), kw...)
 
 @functor TwinNetwork (source,)
 
-(model::TwinNetwork)(x) = model.source(x)
+(model::TwinNetwork)(args...) = model.source(args...)
 
 function RLBase.optimise!(A::Approximator{<:TwinNetwork}, gs)
     Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
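
The same change on TwinNetwork lets the source/target pair wrap a multi-argument model. A small sketch, assuming the convenience constructor at line 397 above and the sync_freq option used in the experiment below; ToyNet is a hypothetical stand-in for ImplicitQuantileNet:

    using Flux
    using ReinforcementLearningCore

    # Hypothetical two-input model; only its call signature matters here.
    struct ToyNet
        w::Matrix{Float32}
    end
    Flux.@functor ToyNet
    (m::ToyNet)(s, emb) = m.w * (s .+ emb)

    twin = TwinNetwork(ToyNet(rand(Float32, 2, 4)); sync_freq=100)

    s, emb = rand(Float32, 4, 32), rand(Float32, 4, 32)
    twin(s, emb)   # forwarded to twin.source(s, emb)
    twin.target    # the delayed copy created by TwinNetwork(x; kw...)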

src/ReinforcementLearningExperiments/deps/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl

Lines changed: 47 additions & 40 deletions
@@ -3,7 +3,7 @@
 # title: JuliaRL\_IQN\_CartPole
 # cover: assets/JuliaRL_IQN_CartPole.png
 # description: IQN applied to CartPole
-# date: 2021-05-22
+# date: 2022-06-27
 # author: "[Jun Tian](https://github.com/findmyway)"
 # ---
 
@@ -12,18 +12,16 @@ using ReinforcementLearning
 using StableRNGs
 using Flux
 using Flux.Losses
-using CUDA
 
 function RL.Experiment(
     ::Val{:JuliaRL},
     ::Val{:IQN},
     ::Val{:CartPole},
-    ::Nothing;
-    seed = 123,
+    ; seed=123
 )
     rng = StableRNG(seed)
-    device_rng = CUDA.functional() ? CUDA.CURAND.RNG() : rng
-    env = CartPoleEnv(; T = Float32, rng = rng)
+    device_rng = rng
+    env = CartPoleEnv(; T=Float32, rng=rng)
     ns, na = length(state(env)), length(action_space(env))
     init = glorot_uniform(rng)
     Nₑₘ = 16
@@ -32,51 +30,60 @@ function RL.Experiment(
 
     nn_creator() =
         ImplicitQuantileNet(
-            ψ = Dense(ns, n_hidden, relu; init = init),
-            ϕ = Dense(Nₑₘ, n_hidden, relu; init = init),
-            header = Dense(n_hidden, na; init = init),
+            ψ=Dense(ns, n_hidden, relu; init=init),
+            ϕ=Dense(Nₑₘ, n_hidden, relu; init=init),
+            header=Dense(n_hidden, na; init=init),
         ) |> gpu
 
     agent = Agent(
-        policy = QBasedPolicy(
-            learner = IQNLearner(
-                approximator = NeuralNetworkApproximator(
-                    model = nn_creator(),
-                    optimizer = ADAM(0.001),
+        policy=QBasedPolicy(
+            learner=IQNLearner(
+                approximator=Approximator(
+                    model=TwinNetwork(
+                        ImplicitQuantileNet(
+                            ψ=Dense(ns, n_hidden, relu; init=init),
+                            ϕ=Dense(Nₑₘ, n_hidden, relu; init=init),
+                            header=Dense(n_hidden, na; init=init),
+                        ),
+                        sync_freq=100
+                    ),
+                    optimiser=ADAM(0.001),
                 ),
-                target_approximator = NeuralNetworkApproximator(model = nn_creator()),
-                κ = κ,
-                N = 8,
-                N′ = 8,
-                Nₑₘ = Nₑₘ,
-                K = 32,
-                γ = 0.99f0,
-                stack_size = nothing,
-                batch_size = 32,
-                update_horizon = 1,
-                min_replay_history = 100,
-                update_freq = 1,
-                target_update_freq = 100,
-                default_priority = 1.0f2,
-                rng = rng,
-                device_rng = device_rng,
+                κ=κ,
+                N=8,
+                N′=8,
+                Nₑₘ=Nₑₘ,
+                K=32,
+                γ=0.99f0,
+                rng=rng,
+                device_rng=device_rng,
             ),
-            explorer = EpsilonGreedyExplorer(
-                kind = :exp,
-                ϵ_stable = 0.01,
-                decay_steps = 500,
-                rng = rng,
+            explorer=EpsilonGreedyExplorer(
+                kind=:exp,
+                ϵ_stable=0.01,
+                decay_steps=500,
+                rng=rng,
            ),
         ),
-        trajectory = CircularArrayPSARTTrajectory(
-            capacity = 1000,
-            state = Vector{Float32} => (ns,),
-        ),
+        trajectory=Trajectory(
+            container=CircularArraySARTTraces(
+                capacity=1000,
+                state=Float32 => (ns,),
+            ),
+            sampler=BatchSampler{SS′ART}(
+                batch_size=32,
+                rng=rng
+            ),
+            controller=InsertSampleRatioController(
+                threshold=100,
+                n_inserted=-1
+            )
+        )
     )
 
     stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
     hook = TotalRewardPerEpisode()
-    Experiment(agent, env, stop_condition, hook, "")
+    Experiment(agent, env, stop_condition, hook)
 end
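
Once registered in ReinforcementLearningExperiments below, the experiment can be run the same way the updated runtests.jl does; a typical session, assuming both packages are installed:

    using ReinforcementLearning
    using ReinforcementLearningExperiments

    ex = E`JuliaRL_IQN_CartPole`  # builds the Experiment defined above (default seed = 123)
    run(ex)                       # trains for 10_000 steps; TotalRewardPerEpisode records the returns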

src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_DQN_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_PrioritizedDQN_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_QRDQN_CartPole.jl"))
 include(joinpath(EXPERIMENTS_DIR, "JuliaRL_REMDQN_CartPole.jl"))
+include(joinpath(EXPERIMENTS_DIR, "JuliaRL_IQN_CartPole.jl"))
 
 # dynamic loading environments
 function __init__() end

src/ReinforcementLearningExperiments/test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -8,9 +8,9 @@ run(E`JuliaRL_DQN_CartPole`)
 run(E`JuliaRL_PrioritizedDQN_CartPole`)
 run(E`JuliaRL_QRDQN_CartPole`)
 run(E`JuliaRL_REMDQN_CartPole`)
+run(E`JuliaRL_IQN_CartPole`)
 # run(E`JuliaRL_BC_CartPole`)
 # run(E`JuliaRL_Rainbow_CartPole`)
-# run(E`JuliaRL_IQN_CartPole`)
 # run(E`JuliaRL_VMPO_CartPole`)
 # run(E`JuliaRL_VPG_CartPole`)
 # run(E`JuliaRL_BasicDQN_MountainCar`)

src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@ include("dqn.jl")
 include("prioritized_dqn.jl")
 include("qr_dqn.jl")
 include("rem_dqn.jl")
+include("iqn.jl")
 # include("rainbow.jl")
-# include("iqn.jl")
 # include("common.jl")
