Rework the run loop #921
The agent include list is updated:

```diff
@@ -1,3 +1,3 @@
-include("base.jl")
+include("agent_base.jl")
 include("agent_srt_cache.jl")
 include("multi_agent.jl")
```
A new file adds the `Agent` wrapper (64 added lines):

```julia
export Agent

using Base.Threads: @spawn

using Functors: @functor
import Base.push!

"""
    Agent(; policy, trajectory) <: AbstractPolicy

A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but
update the trajectory and policy appropriately at different stages. `Agent`
is callable, and its call method accepts varargs and keyword arguments that
are passed on to the policy.
"""
mutable struct Agent{P,T} <: AbstractPolicy
    policy::P
    trajectory::T

    function Agent(policy::P, trajectory::T) where {P<:AbstractPolicy,T<:Trajectory}
        agent = new{P,T}(policy, trajectory)

        if TrajectoryStyle(trajectory) === AsyncTrajectoryStyle()
            bind(trajectory, @spawn(optimise!(policy, trajectory)))
        end
        agent
    end
end

Agent(; policy, trajectory) = Agent(policy, trajectory)

RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} =
    RLBase.optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} =
    RLBase.optimise!(agent.policy, stage, agent.trajectory)

# a task to optimise the inner policy was already spawned when the agent was constructed
RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing

# by default, optimise! does nothing at any stage
function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end

@functor Agent (policy,)
```
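Since the fallback `optimise!` is a no-op, a concrete policy opts in by specializing the three-argument method. A minimal sketch of what that might look like (`MyPolicy` and `learn!` are hypothetical names; the batch iteration assumes the trajectory yields batches according to its sampler/controller when iterated):

```julia
# Hypothetical sketch: a policy that updates after every environment step.
# `MyPolicy` and `learn!` are invented for illustration.
struct MyPolicy <: AbstractPolicy end

function RLBase.optimise!(p::MyPolicy, ::PostActStage, trajectory::Trajectory)
    for batch in trajectory   # batches produced per the trajectory's sampler/controller
        learn!(p, batch)      # hypothetical parameter-update step
    end
end
```

With a `SyncTrajectoryStyle` trajectory this runs inline on each `optimise!(agent, PostActStage())` call; with `AsyncTrajectoryStyle`, the constructor above has already bound a background task to the trajectory instead.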
The rest of the file implements the stage hooks:

```julia
function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv)
    push!(agent.trajectory, (state = state(env),))
end

# !!! TODO: In async scenarios, parameters of the policy may still be updating
# (partially), which will result in an incorrect action. This should be
# addressed in Oolong.jl with a wrapper
function RLBase.plan!(agent::Agent, env::AbstractEnv)
    RLBase.plan!(agent.policy, env)
end

function Base.push!(agent::Agent, ::PostActStage, env::AbstractEnv, action)
    next_state = state(env)
    push!(agent.trajectory, (state = next_state, action = action, reward = reward(env), terminal = is_terminated(env)))
end

function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv)
    if haskey(agent.trajectory, :next_action)
        action = RLBase.plan!(agent.policy, env)
        push!(agent.trajectory, PartialNamedTuple((action = action,)))
    end
end
```
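For orientation, here is roughly how the reworked run loop drives these hooks. This is a hedged sketch, not the PR's actual loop: `CartPoleEnv`, `RandomPolicy`, `CircularArraySARTTraces`, `BatchSampler`, and the trace element types are assumed from the ReinforcementLearning.jl ecosystem, and some stages are elided.

```julia
using ReinforcementLearning

env = CartPoleEnv()   # 4-dimensional Float32 observations (assumed)
agent = Agent(
    policy = RandomPolicy(),
    trajectory = Trajectory(
        container = CircularArraySARTTraces(capacity = 1_000, state = Float32 => (4,)),
        sampler = BatchSampler(32),
    ),
)

push!(agent, PreEpisodeStage(), env)           # record the initial state
while !is_terminated(env)
    action = RLBase.plan!(agent, env)          # delegate to the wrapped policy
    act!(env, action)
    push!(agent, PostActStage(), env, action)  # record (s′, a, r, terminal)
    optimise!(agent, PostActStage())           # no-op unless the policy specializes it
end
push!(agent, PostEpisodeStage(), env)          # completes :next_action, if present
```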
Review thread on the `PostActStage` push:

Reviewer: Maybe I got something wrong, but shouldn't `next_state` be stored in the `next_state` field of the trajectory? `next_state` is the successor of the state before the action was taken in the environment, right?

Author: It's the same. Both names point to the same `Trace` in the trajectory.

Reviewer: Right. It is the multiplex trace, right?

Author: Yes.
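The "multiplex" trace mentioned here stores a single sequence exposed under two names with a one-step offset. A sketch of the idea, assuming the `MultiplexTraces` type from ReinforcementLearningTrajectories.jl (exact push/indexing semantics may vary by version):

```julia
using ReinforcementLearningTrajectories: MultiplexTraces

# one underlying vector, exposed as both :state and :next_state
t = MultiplexTraces{(:state, :next_state)}(Int[])

push!(t, (state = 1,))   # s₁
push!(t, (state = 2,))   # s₂; simultaneously the next_state of step 1
push!(t, (state = 3,))

t[:state]       # view of all but the last element  → [1, 2]
t[:next_state]  # view of all but the first element → [2, 3]
```

This is why pushing `state = next_state` and writing to a `next_state` field land in the same storage.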
Review thread on the `haskey(agent.trajectory, :next_action)` check:

Reviewer: So if the episode finished (whether truncated or terminated), we query the policy to plan another step. Shouldn't we also check that the environment is not terminated? If it is, it just makes no sense to plan an action.

Author: It wouldn't make sense indeed, but if your environment has terminal states at all, then you should not use a trajectory that has a `next_action` key. That's how I thought about it. If we added that check, it would allow the user to have an incorrect trajectory without an error being thrown, and the buffer would accumulate mistakes.
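To make the trade-off concrete: a SARSA-style trajectory multiplexes the action trace as `(:action, :next_action)`, so after the final `PostActStage` push the action column is one element short, and `PostEpisodeStage` plans one extra action to complete it. A hedged sketch of such a container (the `+` composition follows the ReinforcementLearningTrajectories.jl pattern; exact constructors may differ):

```julia
using ReinforcementLearningTrajectories

# SARSA-style traces: both state and action are multiplexed (assumed API)
container = MultiplexTraces{(:state, :next_state)}(Int[]) +
            MultiplexTraces{(:action, :next_action)}(Int[]) +
            Traces(reward = Float64[], terminal = Bool[])

traj = Trajectory(container = container, sampler = BatchSampler(1))

haskey(traj, :next_action)  # true ⇒ PostEpisodeStage plans one more action
```

Per the author's reasoning, pairing such a container with an environment that has true terminal states is a user error, and omitting the check means the mistake surfaces as an error instead of silently corrupting the buffer.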
On a deleted file, one more comment:

"This is begging for us to create an `action` type, but that's something for another PR. :)"