Episode reset condition #621

Merged
merged 18 commits into from
Jul 1, 2022
1 change: 1 addition & 0 deletions docs/make.jl
@@ -49,6 +49,7 @@ makedocs(
"How to implement a new algorithm?" => "How_to_implement_a_new_algorithm.md",
"How to use hooks?" => "How_to_use_hooks.md",
"Which algorithm should I use?" => "Which_algorithm_should_I_use.md",
+"Episodic vs. Non-episodic environments" => "non_episodic.md",
],
"FAQ" => "FAQ.md",
experiments,
46 changes: 46 additions & 0 deletions docs/src/non_episodic.md
@@ -0,0 +1,46 @@
# Episodic vs Non-episodic environments

## Episodic environments
By default, `run(policy, env, stop_condition, hook)` will step through `env` until a terminal state is reached, signaling the end of an episode. To be able to do so, `env` must implement the `RLBase.is_terminated(::YourEnvironment)` function. This function is called after each step through the environment and when it returns `true`, the trajectory records the terminal state, then the `RLBase.reset!(::YourEnvironment)` function is called and the environment is set to (one of) its initial state(s).

As a consequence, the value of the terminal state is set to 0 when learning values via bootstrapping.
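
To make the bootstrapping remark concrete, here is a minimal sketch of a one-step TD target with a termination mask. The helper `td_target` and its argument names are hypothetical and chosen for illustration only; they are not part of the package's API.

```julia
# Hypothetical helper, not part of ReinforcementLearning.jl:
# one-step TD target with a termination mask.
#   r          reward received on the transition
#   γ          discount factor
#   v_next     current estimate of the next state's value
#   terminated whether the next state is terminal
td_target(r, γ, v_next, terminated::Bool) = r + γ * (1 - terminated) * v_next

bootstrapped = td_target(1.0, 0.9, 5.0, false)  # uses v_next
masked = td_target(1.0, 0.9, 5.0, true)         # v_next is multiplied by 0
```

When `terminated` is `true` the target reduces to the reward alone, which is what treating the terminal state's value as 0 amounts to.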

## Non-episodic environments

Also called _continuing tasks_ (Sutton & Barto, 2018), non-episodic environments do not have a terminal state and thus may run forever, or until the `stop_condition` is reached. Sometimes, however, one may want to periodically reset the environment to start fresh. A first possibility is to implement `RLBase.is_terminated(::YourEnvironment)` so that it returns `true` according to an arbitrary condition. However, this may not be a good idea because the value of the last state (note that it is not a _terminal_ state) will be bootstrapped to 0 during learning, even though that is not the true value of the state.

To manage this, we provide the `ResetAfterNSteps(n)` condition, which can be passed as the `reset_condition` argument of `run(policy, env, stop_condition, hook, reset_condition = ResetAtTerminal())`. The default, `ResetAtTerminal()`, assumes an episodic environment. Changing it to `ResetAfterNSteps(n)` means that `run` no longer checks `is_terminated` and instead calls `reset!` every `n` steps. This way, the value of the last state is not multiplied by 0 during bootstrapping and the correct value can be learned.

## Custom reset conditions

You can specify a custom `reset_condition` instead of using the built-in ones. Your condition must be callable as `my_condition(policy, env)`. For example, here is how to implement a custom condition that checks for a terminal state but will also reset if the episode is too long:

```julia
reset_n_steps = ResetAfterNSteps(10000)

function my_condition(policy, env)
    terminal = is_terminated(env)
    too_long = reset_n_steps(policy, env)
    return terminal || too_long
end

run(agent, env, stop_condition, hook, my_condition)
```

Alternatively, we can use a callable struct instead of a function to avoid the global `reset_n_steps`.

```julia
mutable struct MyCondition
    reset_after
end

(c::MyCondition)(policy, env) = is_terminated(env) || c.reset_after(policy, env)

run(agent, env, stop_condition, hook, MyCondition(ResetAfterNSteps(10000)))
```
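
One detail worth noting: going by the definitions shown in this diff, the counter inside `ResetAfterNSteps` is only cleared when the step limit itself triggers, not when the reset comes from a terminal state, so the limit is not counted per episode. The following is a sketch of one way to count per episode, under the assumption that this is the desired behavior. `TerminalOrTooLong` and `ToyEnv` are hypothetical names, and `is_terminated` is redefined locally as a stand-in so the block runs without the package.

```julia
# Stand-ins so the sketch runs without ReinforcementLearning.jl loaded.
mutable struct ToyEnv
    done::Bool
end
is_terminated(env::ToyEnv) = env.done   # stand-in for RLBase.is_terminated

# Hypothetical condition: reset at a terminal state or after `n` steps,
# restarting the step count on every reset.
mutable struct TerminalOrTooLong
    t::Int
    n::Int
end
TerminalOrTooLong(n::Int) = TerminalOrTooLong(0, n)

function (c::TerminalOrTooLong)(policy, env)
    if is_terminated(env) || c.t >= c.n
        c.t = 0          # start counting afresh for the next episode
        return true
    end
    c.t += 1
    return false
end

env = ToyEnv(false)
c = TerminalOrTooLong(3)
c(nothing, env); c(nothing, env)          # two steps are counted
env.done = true
reset_now = c(nothing, env)               # terminal state: reset and clear the counter
```

With the real package, an instance of such a struct would be passed in the `reset_condition` position of `run`, in the same way as `MyCondition` above.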

A last possibility is to use an anonymous function. This approach is less convenient for stateful conditions (such as `ResetAfterNSteps`), since the state would have to be captured in a closure. For example, here is an alternative way to implement `ResetAtTerminal`:

```julia
run(agent, env, stop_condition, hook, (p,e) -> is_terminated(e))
```
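
For completeness, a closure can carry state if it captures a mutable binding such as a `Ref`. The sketch below assumes only plain Julia. `reset_every` is a hypothetical helper, not something the package provides, and `nothing` stands in for the policy and environment.

```julia
# Hypothetical helper: builds a stateful reset condition as a closure.
function reset_every(n::Int)
    t = Ref(0)                       # mutable counter captured by the closure
    return (policy, env) -> begin
        if t[] >= n
            t[] = 0                  # limit reached: clear the counter and reset
            return true
        end
        t[] += 1
        return false
    end
end

cond = reset_every(2)
results = [cond(nothing, nothing) for _ in 1:6]
```

A callable struct is arguably easier to inspect and to type, which may be why the built-in conditions are written that way.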
1 change: 1 addition & 0 deletions src/ReinforcementLearningCore/src/core/core.jl
@@ -2,3 +2,4 @@ include("stages.jl")
include("stop_conditions.jl")
include("hooks.jl")
include("run.jl")
+include("reset_conditions.jl")
33 changes: 33 additions & 0 deletions src/ReinforcementLearningCore/src/core/reset_conditions.jl
@@ -0,0 +1,33 @@
export ResetAtTerminal, ResetAfterNSteps

"""
    ResetAtTerminal()

A reset condition that resets the environment if `is_terminated(env)` is true.
"""
struct ResetAtTerminal end

(::ResetAtTerminal)(policy, env) = is_terminated(env)

"""
    ResetAfterNSteps(n)

A reset condition that resets the environment after `n` steps.
"""
mutable struct ResetAfterNSteps
    t::Int
    n::Int
end

ResetAfterNSteps(n::Int) = ResetAfterNSteps(0, n)

function (r::ResetAfterNSteps)(policy, env)
    stop = r.t >= r.n
    r.t += 1
    if stop
        r.t = 0
        return true
    else
        return false
    end
end
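
To check how the counter interacts with a loop of the form `while !reset_condition(policy, env)`, here is a small standalone sketch. It copies the `ResetAfterNSteps` definition from this diff so that it runs without the package. The function `steps_per_segment` is a hypothetical stand-in for the inner loop of `_run`, and `nothing` is passed in place of a real policy and environment.

```julia
# Copy of the definition from this diff, so the sketch runs standalone.
mutable struct ResetAfterNSteps
    t::Int
    n::Int
end
ResetAfterNSteps(n::Int) = ResetAfterNSteps(0, n)

function (r::ResetAfterNSteps)(policy, env)
    stop = r.t >= r.n
    r.t += 1
    if stop
        r.t = 0
        return true
    else
        return false
    end
end

# Toy stand-in for the inner loop of `_run`: count the steps taken before each reset.
function steps_per_segment(reset_condition, segments)
    counts = Int[]
    for _ in 1:segments
        steps = 0
        while !reset_condition(nothing, nothing)
            steps += 1      # a real loop would act and step the environment here
        end
        push!(counts, steps)
    end
    return counts
end

counts = steps_per_segment(ResetAfterNSteps(3), 2)
```

Because the condition is checked before each step, it returns `true` on the call after the `n`-th step, so each segment contains exactly `n` steps.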
11 changes: 6 additions & 5 deletions src/ReinforcementLearningCore/src/core/run.jl
@@ -60,17 +60,18 @@ Base.run(ex::Experiment) = run(ex.policy, ex.env, ex.stop_condition, ex.hook)
 function Base.run(
     policy::AbstractPolicy,
     env::AbstractEnv,
-    stop_condition=StopAfterEpisode(1),
-    hook=EmptyHook(),
+    stop_condition = StopAfterEpisode(1),
+    hook = EmptyHook(),
+    reset_condition = ResetAtTerminal()
 )
     policy, env = check(policy, env)
-    _run(policy, env, stop_condition, hook)
+    _run(policy, env, stop_condition, hook, reset_condition)
 end

 "Inject some customized checkings here by overwriting this function"
 check(policy, env) = policy, env

-function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook)
+function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook, reset_condition)

     hook(PreExperimentStage(), policy, env)
     policy(PreExperimentStage(), env)
@@ -80,7 +81,7 @@ function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook)
         policy(PreEpisodeStage(), env)
         hook(PreEpisodeStage(), policy, env)

-        while !is_terminated(env) # one episode
+        while !reset_condition(policy, env) # one episode
             policy(PreActStage(), env)
             hook(PreActStage(), policy, env)
