
Commit 4e5d258

add MADDPG algorithm (#444)
* add maddpg
* add experiment
* update cspell.json
* update the algo
1 parent e460aa2 commit 4e5d258

File tree

5 files changed: +273, -2 lines

.cspell/cspell.json

Lines changed: 3 additions & 2 deletions
@@ -120,7 +120,8 @@
 "Norouzi",
 "gzopen",
 "turbulences",
-"Decompressor"
+"Decompressor",
+"MADDPG"
 ],
 "ignoreWords": [],
 "minWordLength": 5,
@@ -143,4 +144,4 @@
 "\\{%.*%\\}", // liquid syntax
 "/^\\s*```[\\s\\S]*?^\\s*```/gm" // Another attempt at markdown code blocks. https://github.com/streetsidesoftware/vscode-spell-checker/issues/202#issuecomment-377477473
 ]
-}
+}
docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# ---
# title: JuliaRL\_MADDPG\_KuhnPoker
# cover: assets/JuliaRL_MADDPG_KuhnPoker.png
# description: MADDPG applied to KuhnPoker
# date: 2021-08-09
# author: "[Peter Chen](https://github.com/peterchen96)"
# ---

#+ tangle=true
using ReinforcementLearning
using StableRNGs
using Flux
using IntervalSets

mutable struct ResultNEpisode <: AbstractHook
    eval_freq::Int
    episode_counter::Int
    episode::Vector{Int}
    results::Vector{Float64}
end

function (hook::ResultNEpisode)(::PostEpisodeStage, policy, env)
    hook.episode_counter += 1
    if hook.episode_counter % hook.eval_freq == 0
        push!(hook.episode, hook.episode_counter)
        push!(hook.results, reward(env, 1))
    end
end

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:MADDPG},
    ::Val{:KuhnPoker},
    ::Nothing;
    seed=123,
)
    rng = StableRNG(seed)
    env = KuhnPokerEnv()
    wrapped_env = ActionTransformedEnv(
        StateTransformedEnv(
            env;
            state_mapping = s -> [findfirst(==(s), state_space(env))],
            state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
        ),
        ## add a dummy action for the other agent.
        action_mapping = x -> length(x) == 1 ? x : Int(x[current_player(env)] + 1),
    )
    ns, na = 1, 1
    n_players = 2

    init = glorot_uniform(rng)

    create_actor() = Chain(
        Dense(ns, 64, relu; init = init),
        Dense(64, 64, relu; init = init),
        Dense(64, na, tanh; init = init),
    )

    create_critic() = Chain(
        Dense(n_players * ns + n_players * na, 64, relu; init = init),
        Dense(64, 64, relu; init = init),
        Dense(64, 1; init = init),
    )


    policy = DDPGPolicy(
        behavior_actor = NeuralNetworkApproximator(
            model = create_actor(),
            optimizer = ADAM(),
        ),
        behavior_critic = NeuralNetworkApproximator(
            model = create_critic(),
            optimizer = ADAM(),
        ),
        target_actor = NeuralNetworkApproximator(
            model = create_actor(),
            optimizer = ADAM(),
        ),
        target_critic = NeuralNetworkApproximator(
            model = create_critic(),
            optimizer = ADAM(),
        ),
        γ = 0.99f0,
        ρ = 0.995f0,
        na = na,
        start_steps = 1000,
        start_policy = RandomPolicy(-0.9..0.9; rng = rng),
        update_after = 1000,
        act_limit = 0.9,
        act_noise = 0.1,
        rng = rng,
    )
    trajectory = CircularArraySARTTrajectory(
        capacity = 10000, # replay buffer capacity
        state = Vector{Int} => (ns, ),
        action = Float32 => (na, ),
    )

    agents = MADDPGManager(
        Dict((player, Agent(
            policy = NamedPolicy(player, deepcopy(policy)),
            trajectory = deepcopy(trajectory),
        )) for player in players(env) if player != chance_player(env)),
        128, # batch_size
        128, # update_freq
        0, # step_counter
        rng
    )

    stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI"))
    hook = ResultNEpisode(1000, 0, [], [])
    Experiment(agents, wrapped_env, stop_condition, hook, "# run MADDPG on KuhnPokerEnv")
end

#+ tangle=false
using Plots
ex = E`JuliaRL_MADDPG_KuhnPoker`
run(ex)
scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="reward of player 1")

savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide

# ![](assets/JuliaRL_MADDPG_KuhnPoker.png)
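A quick illustration of the two wrappers above: `state_mapping` replaces each Kuhn poker state with its integer index in `state_space(env)`, so the networks only ever see a length-1 vector, and the actor's continuous output in [-0.9, 0.9] is discretized by the manager's `ceil` plus the wrapper's `Int(x + 1)` shift. A minimal sketch of that arithmetic, using a hypothetical 3-element state space and made-up actor outputs (neither is part of the commit):

    ## hypothetical stand-in for `state_space(env)`
    toy_state_space = [(:J,), (:Q,), (:K,)]
    state_mapping(s) = [findfirst(==(s), toy_state_space)]
    state_mapping((:Q,))                     # [2] -- the index the networks see

    ## made-up actor outputs in [-0.9, 0.9]; `ceil` then `+ 1` yields 1-based discrete actions
    raw_outputs = [-0.73f0, 0.12f0, 0.85f0]
    Int.(ceil.(raw_outputs) .+ 1)            # [1, 2, 2]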

docs/experiments/experiments/Policy Gradient/config.json

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 "JuliaRL_A2C_CartPole.jl",
 "JuliaRL_A2CGAE_CartPole.jl",
 "JuliaRL_DDPG_Pendulum.jl",
+"JuliaRL_MADDPG_KuhnPoker.jl",
 "JuliaRL_MAC_CartPole.jl",
 "JuliaRL_PPO_CartPole.jl",
 "JuliaRL_PPO_Pendulum.jl",
src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
export MADDPGManager

"""
    MADDPGManager(; agents::Dict{<:Any, <:Agent}, args...)

Multi-agent Deep Deterministic Policy Gradient (MADDPG) implemented in Julia. It currently only works for simultaneous games with a discrete action space.
See the paper https://arxiv.org/abs/1706.02275 for more details.

# Keyword arguments

- `agents::Dict{<:Any, <:Agent{<:NamedPolicy{<:DDPGPolicy, <:Any}, <:AbstractTrajectory}}`, here each agent collects its own information. While updating the policy, each `critic` will assemble all agents' trajectories to update its own network.
- `batch_size::Int`
- `update_freq::Int`
- `update_step::Int`, counts the update steps.
- `rng::AbstractRNG`.
"""
mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory, N<:Any} <: AbstractPolicy
    agents::Dict{<:N, <:Agent{<:NamedPolicy{<:P, <:N}, <:T}}
    batch_size::Int
    update_freq::Int
    update_step::Int
    rng::AbstractRNG
end

# for simultaneous games with a discrete action space.
function (π::MADDPGManager)(env::AbstractEnv)
    while current_player(env) == chance_player(env)
        env |> legal_action_space |> rand |> env
    end
    Dict((player, ceil(agent.policy(env))) for (player, agent) in π.agents)
end

function (π::MADDPGManager)(stage::Union{PreEpisodeStage, PostActStage}, env::AbstractEnv)
    # only need to update each trajectory.
    for (_, agent) in π.agents
        update!(agent.trajectory, agent.policy, env, stage)
    end
end

function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
    # update each agent's trajectory.
    for (player, agent) in π.agents
        update!(agent.trajectory, agent.policy, env, stage, actions[player])
    end

    # update policy
    update!(π)
end

function (π::MADDPGManager)(stage::PostEpisodeStage, env::AbstractEnv)
    # collect the final state and a dummy action into each agent's trajectory.
    for (_, agent) in π.agents
        update!(agent.trajectory, agent.policy, env, stage)
    end

    # update policy
    update!(π)
end

# update policy
function RLBase.update!(π::MADDPGManager)
    π.update_step += 1
    π.update_step % π.update_freq == 0 || return

    for (_, agent) in π.agents
        length(agent.trajectory) > agent.policy.policy.update_after || return
        length(agent.trajectory) > π.batch_size || return
    end

    # get training data
    temp_player = collect(keys(π.agents))[1]
    t = π.agents[temp_player].trajectory
    inds = rand(π.rng, 1:length(t), π.batch_size)
    batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds))
        for (player, agent) in π.agents)

    # get s, a, s′ for the critic
    s = Flux.stack((batches[player][:state] for (player, _) in π.agents), 1)
    a = Flux.stack((batches[player][:action] for (player, _) in π.agents), 1)
    s′ = Flux.stack((batches[player][:next_state] for (player, _) in π.agents), 1)

    # for training the behavior_actor
    mu_actions = Flux.stack(
        ((
            batches[player][:state] |> # get personal state information
            x -> send_to_device(device(agent.policy.policy.behavior_actor), x) |>
            agent.policy.policy.behavior_actor |> send_to_host
        ) for (player, agent) in π.agents), 1
    )
    # for training the behavior_critic
    new_actions = Flux.stack(
        ((
            batches[player][:next_state] |> # get personal next_state information
            x -> send_to_device(device(agent.policy.policy.target_actor), x) |>
            agent.policy.policy.target_actor |> send_to_host
        ) for (player, agent) in π.agents), 1
    )

    for (player, agent) in π.agents
        p = agent.policy.policy # get the DDPGPolicy struct
        A = p.behavior_actor
        C = p.behavior_critic
        Aₜ = p.target_actor
        Cₜ = p.target_critic

        γ = p.γ
        ρ = p.ρ

        _device(x) = send_to_device(device(A), x)

        # Note: A, C, Aₜ and Cₜ are assumed to be on the same device by default.
        s, a, s′ = _device((s, a, s′))
        mu_actions = _device(mu_actions)
        new_actions = _device(new_actions)
        r = _device(batches[player][:reward])
        t = _device(batches[player][:terminal])

        qₜ = Cₜ(vcat(s′, new_actions)) |> vec
        y = r .+ γ .* (1 .- t) .* qₜ

        gs1 = gradient(Flux.params(C)) do
            q = C(vcat(s, a)) |> vec
            loss = mean((y .- q) .^ 2)
            ignore() do
                p.critic_loss = loss
            end
            loss
        end

        update!(C, gs1)

        gs2 = gradient(Flux.params(A)) do
            loss = -mean(C(vcat(s, mu_actions)))
            ignore() do
                p.actor_loss = loss
            end
            loss
        end

        update!(A, gs2)

        # polyak averaging
        for (dest, src) in zip(Flux.params([Aₜ, Cₜ]), Flux.params([A, C]))
            dest .= ρ .* dest .+ (1 - ρ) .* src
        end
    end
end
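One detail worth spelling out: the critic here is centralized, so `vcat(s, a)` feeds it every player's state and action at once, which is why the experiment's `create_critic()` takes `n_players * ns + n_players * na` inputs while each actor only sees its own `ns`-dimensional state. A minimal shape check under the experiment's settings (ns = na = 1, n_players = 2, batch size 128), a sketch using plain 2-D batches instead of the `Flux.stack(..., 1)` layout built above:

    using Flux

    ns, na, n_players, batch_size = 1, 1, 2, 128
    critic = Chain(
        Dense(n_players * ns + n_players * na, 64, relu),
        Dense(64, 64, relu),
        Dense(64, 1),
    )
    s = rand(Float32, n_players * ns, batch_size)   # all players' states, stacked
    a = rand(Float32, n_players * na, batch_size)   # all players' actions, stacked
    q = critic(vcat(s, a))                          # size (1, 128): one Q-value per sample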

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ include("MAC.jl")
 include("ddpg.jl")
 include("td3.jl")
 include("sac.jl")
+include("maddpg.jl")
