Skip to content

Commit 798a5dd

Browse files
committed
use NamedTuples for plan!, Base.push! for sim. env
1 parent b54a0b0 commit 798a5dd

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

src/ReinforcementLearningCore/src/policies/agent/base.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajecto
5151

5252
@functor Agent (policy,)
5353

54-
function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv)
55-
push!(agent, state(env))
56-
end
57-
5854
# !!! TODO: In async scenarios, parameters of the policy may still be updating
5955
# (partially), which will result to incorrect action. This should be addressed
6056
# in Oolong.jl with a wrapper
@@ -64,11 +60,8 @@ function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C}
6460
action
6561
end
6662

67-
# Multiagent Version
68-
function RLBase.plan!(agent::Agent{P,T,C}, env::E, p::Symbol) where {P,T,C,E<:AbstractEnv}
69-
action = RLBase.plan!(agent.policy, env, p)
70-
push!(agent.trajectory, agent.cache, action)
71-
action
63+
function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv)
64+
push!(agent, state(env))
7265
end
7366

7467
function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E) where {P,T,C,E<:AbstractEnv}
@@ -79,11 +72,26 @@ function Base.push!(agent::Agent, ::PostExperimentStage, env::E) where {E<:Abstr
7972
RLBase.reset!(agent.cache)
8073
end
8174

82-
function Base.push!(agent::Agent, ::PostExperimentStage, env::E, player::Symbol) where {E<:AbstractEnv}
83-
RLBase.reset!(agent.cache)
84-
end
85-
8675
function Base.push!(agent::Agent{P,T,C}, state::S) where {P,T,C,S}
8776
push!(agent.cache, state)
8877
end
8978

79+
# Multiagent Version
80+
function RLBase.plan!(agent::Agent{P,T,C}, env::E, p::Symbol) where {P,T,C,E<:AbstractEnv}
81+
action = RLBase.plan!(agent.policy, env, p)
82+
push!(agent.trajectory, agent.cache, action)
83+
action
84+
end
85+
86+
# for simultaneous DynamicStyle environments, we have to define push! operations
87+
function Base.push!(agent::Agent, ::PreActStage, env::AbstractEnv, player::Symbol)
88+
push!(agent, state(env, player))
89+
end
90+
91+
function Base.push!(agent::Agent{P,T,C}, ::PostActStage, env::E, player::Symbol) where {P,T,C,E<:AbstractEnv}
92+
push!(agent.cache, reward(env, player), is_terminated(env))
93+
end
94+
95+
function Base.push!(agent::Agent, ::PostExperimentStage, env::E, player::Symbol) where {E<:AbstractEnv}
96+
RLBase.reset!(agent.cache)
97+
end

src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,12 @@ function Base.run(
133133

134134
if check_stop(stop_condition, policy, env)
135135
is_stop = true
136-
@timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
137-
@timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
138-
@timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
139-
@timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
136+
if !is_terminated(env)
137+
@timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
138+
@timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
139+
@timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
140+
@timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
141+
end
140142
break
141143
end
142144

@@ -228,7 +230,7 @@ function Base.push!(composed_hook::ComposedHook{T},
228230
end
229231

230232
function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
231-
return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
233+
return NamedTuple(player => RLBase.plan!(multiagent[player], env, player) for player in players(env))
232234
end
233235

234236
function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}

0 commit comments

Comments
 (0)