Rework the run loop #921

Merged (40 commits, Jul 29, 2023)
Commits
a80db65
bump version compat
HenriDeh Jul 6, 2023
d5209fe
bump version
HenriDeh Jul 6, 2023
a4e24ce
simplify run loop and compat with traj 0.2
HenriDeh Jul 6, 2023
597092b
rename to agent base
HenriDeh Jul 6, 2023
d2a2aaa
optional push at end of episode
HenriDeh Jul 6, 2023
629db63
use new SARST name
HenriDeh Jul 7, 2023
1deca16
bump compats
HenriDeh Jul 7, 2023
815e363
update MA plan
HenriDeh Jul 7, 2023
f930627
fix typing
HenriDeh Jul 7, 2023
6305bb0
agent typing
HenriDeh Jul 7, 2023
b2c9d1b
fix precompile
HenriDeh Jul 7, 2023
522da77
fix first tests
HenriDeh Jul 7, 2023
1ce5f53
Update docs/src/How_to_implement_a_new_algorithm.md
jeremiahpslewis Jul 11, 2023
f668bf6
Update docs/src/Zoo_Algorithms/MPO.md
jeremiahpslewis Jul 11, 2023
e695cad
deactivate VPG and TRPO
HenriDeh Jul 11, 2023
218e86b
Merge branch 'loop-traj' of https://github.com/JuliaReinforcementLear…
HenriDeh Jul 11, 2023
4198403
export sample
HenriDeh Jul 12, 2023
302f885
Merge branch 'main' into loop-traj
HenriDeh Jul 12, 2023
e8ed7b6
fix NFQ
HenriDeh Jul 12, 2023
605bcc9
remove comments
HenriDeh Jul 12, 2023
48c5173
change the MA loop
HenriDeh Jul 12, 2023
8e77654
simultaneous agents
HenriDeh Jul 12, 2023
7c25e95
Move MA stuff to proper file
HenriDeh Jul 12, 2023
a2948fa
Merge branch 'loop-traj' of https://github.com/JuliaReinforcementLear…
HenriDeh Jul 12, 2023
875da76
fix ambiguity
HenriDeh Jul 12, 2023
36d13f9
Bump RLTraj to bug fix version
jeremiahpslewis Jul 25, 2023
6460cd9
Merge branch 'main' into loop-traj
jeremiahpslewis Jul 25, 2023
21ca014
Fix type name
jeremiahpslewis Jul 26, 2023
ea877fc
Merge branch 'main' into loop-traj
jeremiahpslewis Jul 26, 2023
b938655
Drop player to clean up dispatch
jeremiahpslewis Jul 26, 2023
e710880
Add back player
jeremiahpslewis Jul 26, 2023
6ee0306
Fix state push!
jeremiahpslewis Jul 26, 2023
45df594
Update multi_agent.jl
jeremiahpslewis Jul 26, 2023
7b4f964
Broaden type signature
jeremiahpslewis Jul 26, 2023
3d95881
type signature tweak
jeremiahpslewis Jul 26, 2023
d696d78
type tweak
jeremiahpslewis Jul 26, 2023
c7d856d
Update src/ReinforcementLearningCore/Project.toml
jeremiahpslewis Jul 29, 2023
ff4b98f
Minor tweaks
jeremiahpslewis Jul 29, 2023
550222b
Require RLTraj bug fix
jeremiahpslewis Jul 29, 2023
aa55642
Require bug-fixed RLTrajectories
jeremiahpslewis Jul 29, 2023
@@ -14804,7 +14804,7 @@ <h2 id="Understand-the-Trajectories">Understand the <em>Trajectories</em><a clas
<div class="prompt input_prompt">In&nbsp;[28]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">t</span> <span class="o">=</span> <span class="n">Trajectories</span><span class="o">.</span><span class="n">CircularArraySARTTraces</span><span class="p">(;</span><span class="n">capacity</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<div class=" highlight hl-julia"><pre><span></span><span class="n">t</span> <span class="o">=</span> <span class="n">Trajectories</span><span class="o">.</span><span class="n">CircularArraySARTSTraces</span><span class="p">(;</span><span class="n">capacity</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
</pre></div>

</div>
2 changes: 1 addition & 1 deletion docs/src/How_to_implement_a_new_algorithm.md
@@ -94,7 +94,7 @@ A `Trajectory` is composed of three elements: a `container`, a `controller`, and

The container is typically an `AbstractTraces`, an object that store a set of `Trace` in a structured manner. You can either define your own (and contribute it to the package if it is likely to be usable for other algorithms), or use a predefined one if it exists.

The most common `AbstractTraces` object is the `CircularArraySARTTraces`, this is a container of a fixed length that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which toghether are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace.
The most common `AbstractTraces` object is the `CircularArraySARTSTraces`, this is a container of a fixed length that stores the following traces: `:state` (S), `:action` (A), `:reward` (R), `:terminal` (T), which together are aliased to `SART = (:state, :action, :reward, :terminal)`. Let us see how it is constructed in this simplified version as an example of how to build a custom trace.

```julia
function (capacity, state_size, state_eltype, action_size, action_eltype, reward_eltype)
    # … (rest of the constructor collapsed in the diff view)
```
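For context, the collapsed constructor above assembles the container from ReinforcementLearningTrajectories primitives. Below is a rough, hypothetical sketch of the idea only: the helper name `make_sarts_traces` and the keyword defaults are assumptions, and the real `CircularArraySARTSTraces` definition in ReinforcementLearningTrajectories.jl may differ in detail.

```julia
using ReinforcementLearningTrajectories
using CircularArrayBuffers: CircularArrayBuffer

# Hypothetical helper: a SARTS-style container where :state/:next_state are two
# offset views into one circular buffer (a multiplexed trace), while :action,
# :reward and :terminal are plain circular traces.
function make_sarts_traces(; capacity, state = Float32 => (), action = Int => (),
                           reward = Float32 => (), terminal = Bool => ())
    state_eltype, state_size = state
    action_eltype, action_size = action
    reward_eltype, reward_size = reward
    terminal_eltype, terminal_size = terminal

    MultiplexTraces{(:state, :next_state)}(
        CircularArrayBuffer{state_eltype}(state_size..., capacity + 1),
    ) +
    Traces(
        action   = CircularArrayBuffer{action_eltype}(action_size..., capacity),
        reward   = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
        terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
    )
end
```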
2 changes: 1 addition & 1 deletion docs/src/Zoo_Algorithms/MPO.md
@@ -56,7 +56,7 @@ The next step is to wrap this policy into an `Agent`. An agent is a combination

```julia
trajectory = Trajectory(
CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,),action = Float32 => (1,)),
CircularArraySARTSTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (1,)),
MetaSampler(
actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000)
    # … (rest of the Trajectory construction collapsed in the diff view)
```
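The MPO page then wraps this trajectory together with the policy into an `Agent`, using the `policy`/`trajectory` keyword constructor touched by this PR. A minimal sketch, where `mpo_policy` stands for the policy built earlier on that page (the name is an assumption here):

```julia
# `mpo_policy` is assumed to be the MPO policy constructed earlier in the MPO docs.
agent = Agent(policy = mpo_policy, trajectory = trajectory)

# `run` then drives the stages: PreEpisodeStage caches the initial state,
# PostActStage stores (state, action, reward, terminal), and PostEpisodeStage
# optionally adds the next action for trajectories that have a :next_action trace.
```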
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.11.3"
version = "0.12.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -40,7 +40,7 @@ Parsers = "2"
ProgressMeter = "1"
Reexport = "1"
ReinforcementLearningBase = "0.12"
ReinforcementLearningTrajectories = "^0.1.9"
ReinforcementLearningTrajectories = "^0.3.2"
StatsBase = "0.32, 0.33, 0.34"
TimerOutputs = "0.5"
UnicodePlots = "1.3, 2, 3"
@@ -3,7 +3,6 @@ module ReinforcementLearningCore
using TimerOutputs
using ReinforcementLearningBase
using Reexport

const RLCore = ReinforcementLearningCore

export RLCore
8 changes: 2 additions & 6 deletions src/ReinforcementLearningCore/src/core/run.jl
@@ -102,21 +102,17 @@ function _run(policy::AbstractPolicy,
action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
@timeit_debug timer "act!" act!(env, action)

@timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
@timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action)
@timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
@timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

if check_stop(stop_condition, policy, env)
is_stop = true
@timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
@timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
@timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)
@timeit_debug timer "plan!" RLBase.plan!(policy, env) # let the policy see the last observation
break
end
end # end of an episode

@timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
@timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env)
@timeit_debug timer "optimise! PostEpisodeStage" optimise!(policy, PostEpisodeStage())
@timeit_debug timer "push!(hook) PostEpisodeStage" push!(hook, PostEpisodeStage(), policy, env)

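Stripped of the `@timeit_debug` instrumentation, the reworked single-agent episode logic looks roughly like the sketch below. Only the body shown in the hunk is taken from the diff; the episode setup, loop condition, and the `check_stop` helper are paraphrased assumptions about the surrounding code that this PR did not change.

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

# De-instrumented sketch of one episode of the reworked run loop.
function run_episode_sketch(policy, env, stop_condition, hook)
    is_stop = false
    reset!(env)
    push!(policy, PreEpisodeStage(), env)
    push!(hook, PreEpisodeStage(), policy, env)

    while !is_terminated(env)                        # loop condition paraphrased
        push!(policy, PreActStage(), env)
        optimise!(policy, PreActStage())
        push!(hook, PreActStage(), policy, env)

        action = RLBase.plan!(policy, env)
        act!(env, action)

        push!(policy, PostActStage(), env, action)   # the chosen action is now passed to the policy
        optimise!(policy, PostActStage())
        push!(hook, PostActStage(), policy, env)

        if check_stop(stop_condition, policy, env)   # internal helper used in the diff
            is_stop = true
            break                                    # no extra PreActStage push / plan! on stop anymore
        end
    end

    push!(policy, PostEpisodeStage(), env)           # no trailing plan! "to see the last observation"
    optimise!(policy, PostEpisodeStage())
    push!(hook, PostEpisodeStage(), policy, env)
    return is_stop
end
```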
2 changes: 2 additions & 0 deletions src/ReinforcementLearningCore/src/core/stages.jl
@@ -17,7 +17,9 @@ struct PreActStage <: AbstractStage end
struct PostActStage <: AbstractStage end

Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing
Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action) = nothing
Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing
Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action, ::Symbol) = nothing
Review comment (Member):
This is begging for us to create an action type, but that's something for another PR. :)


RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing

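Picking up the reviewer's suggestion above, a dedicated action type could eventually replace the loosely typed `action` argument. This is a purely hypothetical sketch; nothing like it exists in the PR, and the wrapper name is invented for illustration.

```julia
using ReinforcementLearningBase, ReinforcementLearningCore

# Hypothetical only: a thin wrapper so that stage callbacks can dispatch on
# "an action" instead of accepting any value. Not part of this PR.
struct EnvAction{A}
    action::A
end

Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, ::EnvAction) = nothing
```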
2 changes: 1 addition & 1 deletion src/ReinforcementLearningCore/src/policies/agent/agent.jl
@@ -1,3 +1,3 @@
include("base.jl")
include("agent_base.jl")
include("agent_srt_cache.jl")
include("multi_agent.jl")
64 changes: 64 additions & 0 deletions src/ReinforcementLearningCore/src/policies/agent/agent_base.jl
@@ -0,0 +1,64 @@
export Agent

using Base.Threads: @spawn

using Functors: @functor
import Base.push!
"""
Agent(;policy, trajectory) <: AbstractPolicy

A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but to
update the trajectory and policy appropriately in different stages. Agent
is a Callable and its call method accepts varargs and keyword arguments to be
passed to the policy.

"""
mutable struct Agent{P,T} <: AbstractPolicy
policy::P
trajectory::T

function Agent(policy::P, trajectory::T) where {P<:AbstractPolicy, T<:Trajectory}
agent = new{P,T}(policy, trajectory)

if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
bind(trajectory, @spawn(optimise!(policy, trajectory)))
end
agent
end
end

Agent(;policy, trajectory) = Agent(policy, trajectory)

RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = RLBase.optimise!(agent.policy, stage, agent.trajectory)

# already spawn a task to optimise inner policy when initializing the agent
RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing

#by default, optimise does nothing at all stage
function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end

@functor Agent (policy,)

function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv)
push!(agent.trajectory, (state = state(env),))
end

# !!! TODO: In async scenarios, parameters of the policy may still be updating
# (partially), which will result to incorrect action. This should be addressed
# in Oolong.jl with a wrapper
function RLBase.plan!(agent::Agent, env::AbstractEnv)
RLBase.plan!(agent.policy, env)
end

function Base.push!(agent::Agent, ::PostActStage, env::AbstractEnv, action)
next_state = state(env)
push!(agent.trajectory, (state = next_state, action = action, reward = reward(env), terminal = is_terminated(env)))
Review comment (Contributor):
Maybe I got something wrong, but shouldn't next_state be stored in the next_state field of the trajectory? next_state is the successor of the state before the action was taken in the environment, right?

Reply (Member, author):
It's the same. Both names point to the same Trace in the trajectory.

Reply (Contributor):
Right. It is the multiplex trace, right?

Reply (Member, author):
Yes

end
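As the conversation above notes, `:state` and `:next_state` are a multiplexed trace: pushing the post-action `state(env)` under `:state` also serves as the `:next_state` of the previous row. A small sketch, using the keyword format shown in the docs diffs above (the exact indexing behavior may differ slightly by ReinforcementLearningTrajectories version):

```julia
using ReinforcementLearningTrajectories

t = CircularArraySARTSTraces(; capacity = 10, state = Int => (), action = Int => ())
push!(t, (state = 1,))                                               # initial state (PreEpisodeStage)
push!(t, (state = 2, action = 1, reward = 1.0f0, terminal = false))  # full row (PostActStage)

# Both keys read from the same circular buffer, offset by one step:
# t[:state][1] == 1 and t[:next_state][1] == 2 are the expected values.
```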

function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv)
if haskey(agent.trajectory, :next_action)
Review comment (Contributor):
So if the episode finished (whether truncated or terminated), we query the policy to plan another step. Shouldn't we also check that the environment is not terminated? If it is, it makes no sense to plan an action.

Reply (Member, author):
Indeed, it wouldn't make sense, but if your environment has terminal states at all, then you should not use a trajectory that has a next_action key. That's how I thought about it. If we added that check, it would allow the user to have an incorrect trajectory without an error being thrown, and the buffer would accumulate mistakes.

action = RLBase.plan!(agent.policy, env)
push!(agent.trajectory, PartialNamedTuple((action = action, )))
end
end
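To see the new `Agent` end to end, here is a hedged usage sketch. The CartPole setup and the sampler/controller choice are illustrative assumptions (a random policy never triggers `optimise!`, so the sampler is never actually used); the per-stage calls at the bottom mirror what `run` does internally in this PR.

```julia
using ReinforcementLearningBase, ReinforcementLearningCore
using ReinforcementLearningEnvironments: CartPoleEnv
using ReinforcementLearningTrajectories

env = CartPoleEnv()
agent = Agent(
    policy = RandomPolicy(action_space(env)),
    trajectory = Trajectory(
        CircularArraySARTSTraces(; capacity = 1_000, state = Float32 => (4,), action = Int => ()),
        BatchSampler{SS′ART}(32),          # sampler/controller are illustrative choices
        InsertSampleRatioController(),
    ),
)

# The same calls `run` makes internally, one step at a time:
push!(agent, PreEpisodeStage(), env)          # caches the initial state
action = RLBase.plan!(agent, env)             # delegates to agent.policy
act!(env, action)
push!(agent, PostActStage(), env, action)     # stores (state, action, reward, terminal)
push!(agent, PostEpisodeStage(), env)         # pushes :next_action only if that trace exists
```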
@@ -27,12 +27,12 @@ struct SART{S,A,R,T}
end

# This method is used to push a state and action to a trace
function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SA)
function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SA)
push!(ts.traces[1].trace, xs.state)
push!(ts.traces[2].trace, xs.action)
end

function Base.push!(ts::Union{CircularArraySARTTraces,ElasticArraySARTTraces}, xs::SART)
function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SART)
push!(ts.traces[1].trace, xs.state)
push!(ts.traces[2].trace, xs.action)
push!(ts.traces[3], xs.reward)
89 changes: 0 additions & 89 deletions src/ReinforcementLearningCore/src/policies/agent/base.jl

This file was deleted.

49 changes: 33 additions & 16 deletions src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
@@ -125,18 +125,12 @@ function Base.run(
action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
@timeit_debug timer "act!" act!(env, action)



@timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
@timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env, action)
@timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
@timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

if check_stop(stop_condition, policy, env)
is_stop = true
@timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
@timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
@timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
@timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
break
end

Expand Down Expand Up @@ -191,21 +185,43 @@ function Base.push!(multiagent::MultiAgentPolicy, stage::S, env::E) where {S<:Ab
end
end

# Like in the single-agent case, push! at the PreActStage() calls push! on each player with the state of the environment
function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PreActStage, env::E) where {E<:AbstractEnv, names, T <: Agent}
# Like in the single-agent case, push! at the PostActStage() calls push! on each player.
function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv, player::Symbol)
push!(agent.trajectory, (state = state(env, player),))
end

function Base.push!(multiagent::MultiAgentPolicy, s::PreEpisodeStage, env::E) where {E<:AbstractEnv}
for player in players(env)
push!(multiagent[player], state(env, player))
push!(multiagent[player], s, env, player)
end
end

# Like in the single-agent case, push! at the PostActStage() calls push! on each player with the reward and termination status of the environment
function Base.push!(multiagent::MultiAgentPolicy{names, T}, ::PostActStage, env::E) where {E<:AbstractEnv, names, T <: Agent}
for player in players(env)
push!(multiagent[player].cache, reward(env, player), is_terminated(env))
function RLBase.plan!(agent::Agent, env::AbstractEnv, player::Symbol)
RLBase.plan!(agent.policy, env, player)
end

# Like in the single-agent case, push! at the PostActStage() calls push! on each player to store the action, reward, next_state, and terminal signal.
function Base.push!(multiagent::MultiAgentPolicy, ::PostActStage, env::E, actions) where {E<:AbstractEnv}
for (player, action) in zip(players(env), actions)
next_state = state(env, player)
observation = (
state = next_state,
action = action,
reward = reward(env, player),
terminal = is_terminated(env)
)
push!(multiagent[player].trajectory, observation)
end
end

function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv, p::Symbol)
if haskey(agent.trajectory, :next_action)
action = RLBase.plan!(agent.policy, env, p)
push!(agent.trajectory, PartialNamedTuple((action = action, )))
end
end

function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv,S<:AbstractStage}
function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv, S<:AbstractStage}
for player in players(env)
push!(hook[player], stage, multiagent[player], env, player)
end
@@ -227,8 +243,9 @@ function Base.push!(composed_hook::ComposedHook{T},
_push!(stage, policy, env, player, composed_hook.hooks...)
end

#For simultaneous players, plan! returns a Tuple of actions.
function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
return Tuple(RLBase.plan!(multiagent[player], env, player) for player in players(env))
end

function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}
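For simultaneous environments the flow mirrors the single-agent loop, except that `plan!` gathers one action per player into a `Tuple` and the `PostActStage` push fans the per-player rewards and terminal flags back out. A sketch of one step, assuming `multiagent_policy` is an already-constructed `MultiAgentPolicy` of per-player `Agent`s and `env` is a simultaneous `AbstractEnv`:

```julia
# One simultaneous step (names taken from the diff; constructing the policy/env is omitted):
push!(multiagent_policy, PreEpisodeStage(), env)          # each player caches state(env, player)

actions = RLBase.plan!(multiagent_policy, env)            # Tuple of actions, one per player
act!(env, actions)

push!(multiagent_policy, PostActStage(), env, actions)    # per player: (state, action, reward, terminal)
optimise!(multiagent_policy, PostActStage())
```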