add MADDPG algorithm #444
Conversation
You can simply add those words after line 123 of `ReinforcementLearning.jl/.cspell/cspell.json` (as of commit 4973762).
cc @pilgrimygy
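For example, a minimal sketch of the addition (the surrounding file contents are omitted here, not actual entries):

```json
{
    "words": [
        "MADDPG"
    ]
}
```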
Thanks! I'd also like to ask for suggestions on the implementation and on any code mistakes or style issues.
I'll review it later tonight 😃
Looks fine to me
```julia
function (π::MADDPGManager)(::PreEpisodeStage, ::AbstractEnv)
    for (_, agent) in π.agents
        if length(agent.trajectory) > 0
            pop!(agent.trajectory[:state])
            pop!(agent.trajectory[:action])
            if haskey(agent.trajectory, :legal_actions_mask)
                pop!(agent.trajectory[:legal_actions_mask])
            end
        end
    end
end

function (π::MADDPGManager)(::PreActStage, env::AbstractEnv, actions)
    # update each agent's trajectory
    for (player, agent) in π.agents
        push!(agent.trajectory[:state], state(env, player))
        push!(agent.trajectory[:action], actions[player])
        if haskey(agent.trajectory, :legal_actions_mask)
            lasm = legal_action_space_mask(env, player)
            push!(agent.trajectory[:legal_actions_mask], lasm)
        end
    end

    # update policy
    update!(π)
end

function (π::MADDPGManager)(::PostActStage, env::AbstractEnv)
    for (player, agent) in π.agents
        push!(agent.trajectory[:reward], reward(env, player))
        push!(agent.trajectory[:terminal], is_terminated(env))
    end
end

function (π::MADDPGManager)(::PostEpisodeStage, env::AbstractEnv)
    # collect state and dummy action to each agent's trajectory
    for (player, agent) in π.agents
        push!(agent.trajectory[:state], state(env, player))
        push!(agent.trajectory[:action], rand(action_space(env)))
        if haskey(agent.trajectory, :legal_actions_mask)
            lasm = legal_action_space_mask(env, player)
            push!(agent.trajectory[:legal_actions_mask], lasm)
        end
    end

    # update policy
    update!(π)
end
```
How about dispatching to the inner agent's corresponding methods? Like calling `agent(stage, env, action)` in the `for` loop.
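For illustration, a hedged sketch of that suggestion (assuming each value in `π.agents` is a standard `Agent` whose stage methods already do the trajectory bookkeeping above):

```julia
# Forward every stage hook to the wrapped agents so that Agent's own
# trajectory handling is reused instead of re-implemented in the manager.
function (π::MADDPGManager)(stage::AbstractStage, env::AbstractEnv)
    for (_, agent) in π.agents
        agent(stage, env)
    end
end

function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
    for (player, agent) in π.agents
        agent(stage, env, actions[player])  # inner Agent pushes state & action
    end
    update!(π)  # then run the centralized MADDPG update
end
```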
Can you take a look at `NamedPolicy` and see whether we can reuse existing code as much as possible? See also `MultiAgentManager`.
```julia
temp_player = rand(keys(π.agents))
t = π.agents[temp_player].trajectory
```
Simply use the first agent?
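That is, something like:

```julia
# Take any one agent deterministically; assuming all trajectories have the
# same length, sampling a random key adds nothing here.
_, some_agent = first(π.agents)
t = some_agent.trajectory
```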
```julia
temp_player = rand(keys(π.agents))
t = π.agents[temp_player].trajectory
inds = rand(π.rng, 1:length(t), π.batch_size)
batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds))
    for (player, agent) in π.agents)
```
The hardcoded `SARTS` will make the algorithm work only on environments of `MINIMAL_ACTION_SET`.
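One way to avoid the hardcoding could be to pick the trace layout from the trajectory itself (a sketch; it assumes `SLARTSL` is the mask-carrying layout defined in RLCore):

```julia
# Choose SLARTSL when the trajectory carries legal-action masks, so that
# FULL_ACTION_SET environments are supported as well.
traces = haskey(t, :legal_actions_mask) ? SLARTSL : SARTS
batches = Dict(
    player => RLCore.fetch!(BatchSampler{traces}(π.batch_size), agent.trajectory, inds)
    for (player, agent) in π.agents
)
```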
```julia
s = vcat((batches[player][1] for (player, _) in π.agents)...)
a = vcat((batches[player][2] for (player, _) in π.agents)...)
s′ = vcat((batches[player][5] for (player, _) in π.agents)...)
```
`vcat` is not very efficient here. Try `Flux.batch`?
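For instance (a sketch; note that `Flux.batch` stacks along a new trailing dimension rather than concatenating along the first one like `vcat`, so downstream shapes may need adjusting):

```julia
using Flux

# Stack the per-player batches without the splatting overhead of vcat((...)...).
s  = Flux.batch([batches[player][1] for (player, _) in π.agents])
a  = Flux.batch([batches[player][2] for (player, _) in π.agents])
s′ = Flux.batch([batches[player][5] for (player, _) in π.agents])
```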
```julia
s, a, s′ = send_to_host((s, a, s′))
mu_actions = send_to_host(mu_actions)
new_actions = send_to_host(new_actions)
```
Are they required here?
Thanks for your kind reviews! I'll check and update my code later today.
This is still a simple version of MADDPG.
PR Checklist
The description of the implementation is in discussion #404.
Here `MADDPG` raises an `unknown word` error... How can I fix it? @findmyway