Skip to content

Commit 89fe660

Browse files
authored
Merge pull request #621 from HenriDeh/EpisodeResetCondition
Episode reset condition
2 parents b58c7c4 + 8732fb8 commit 89fe660

File tree

5 files changed

+87
-5
lines changed

5 files changed

+87
-5
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ makedocs(
4949
"How to implement a new algorithm?" => "How_to_implement_a_new_algorithm.md",
5050
"How to use hooks?" => "How_to_use_hooks.md",
5151
"Which algorithm should I use?" => "Which_algorithm_should_I_use.md",
52+
"Episodic vs. Non-episodic environments" => "non_episodic.md",
5253
],
5354
"FAQ" => "FAQ.md",
5455
experiments,

docs/src/non_episodic.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Episodic vs Non-episodic environments
2+
3+
## Episodic environments
4+
By default, `run(policy, env, stop_condition, hook)` will step through `env` until a terminal state is reached, signaling the end of an episode. To be able to do so, `env` must implement the `RLBase.is_terminated(::YourEnvironment)` function. This function is called after each step through the environment and when it returns `true`, the trajectory records the terminal state, then the `RLBase.reset!(::YourEnvironment)` function is called and the environment is set to (one of) its initial state(s).
5+
6+
Using this means that the value of the terminal state is set to 0 when learning its value via boostrapping.
7+
8+
## Non-episodic environment
9+
10+
Also called _Continuing tasks_ (Sutton & Barto, 2018), non-episodic environment do not have a terminal state and thus may run for ever, or until the `stop_condition` is reached. Sometimes however, one may want to periodically reset the environment to start fresh. A first possibility is to implement `RLBase.is_terminated(::YourEnvironment)` to reset according to an arbitrary condition. However this may not be a good idea because the value of the last state (note that it is not a _terminal_ state) will be bootstrapped to 0 during learning, even though it is not the true value of the state.
11+
12+
To manage this, we provide the `ResetAfterNSteps(n)` condition as an argument to `run(policy, env, stop_condition, hook, reset_condition = ResetAtTerminal())`. The default `ResetAtTerminal()` assumes an episodic environment, changing that to `ResetAfterNSteps(n)` will no longer check `is_terminated` but will instead call `reset!` every `n` steps. This way, the value of the last state will not be multiplied by 0 during bootstrapping and the correct value can be learned.
13+
14+
## Custom reset conditions
15+
16+
You can specify a custom `reset_condition` instead of using the built-in's. Your condition must be callable with the method `my_condition(policy, env)`. For example, here is how to implement a custom condition that checks for a terminal state but will also reset if the episode is too long:
17+
18+
```julia
19+
reset_n_steps = ResetAfterNSteps(10000)
20+
21+
function my_condition(policy, env)
22+
terminal = is_terminated(env)
23+
too_long = reset_n_steps(policy, env)
24+
return terminal || too_long
25+
end
26+
27+
run(agent, env, stop_condition, hook, my_condition)
28+
```
29+
30+
We can instead make a callable struct instead of a function to avoid the global `reset_n_step`.
31+
32+
```julia
33+
mutable struct MyCondition
34+
reset_after
35+
end
36+
37+
(c::MyCondition)(policy, env) = is_terminated(env) || c.reset_after(policy, env)
38+
39+
run(agent, env, stop_condition, hook, MyCondition(ResetAfterNSteps(10000)))
40+
```
41+
42+
A last possibility is to use an anonymous function. This approach cannot be used to implement stateful conditions (such as `ResetAfterNSteps`). For example here is alternative way to implement `ResetAtTerminal`:
43+
44+
```julia
45+
run(agent, env, stop_condition, hook, (p,e) -> is_terminated(e))
46+
```

src/ReinforcementLearningCore/src/core/core.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include("stages.jl")
22
include("stop_conditions.jl")
33
include("hooks.jl")
44
include("run.jl")
5+
include("reset_conditions.jl")
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
export ResetAtTerminal, ResetAfterNSteps
2+
3+
"""
4+
ResetAtTerminal()
5+
6+
A reset condition that resets the environment if is_terminated(env) is true.
7+
"""
8+
struct ResetAtTerminal end
9+
10+
(::ResetAtTerminal)(policy, env) = is_terminated(env)
11+
12+
"""
13+
ResetAfterNSteps(n)
14+
15+
A reset condition that resets the environment after `n` steps.
16+
"""
17+
mutable struct ResetAfterNSteps
18+
t::Int
19+
n::Int
20+
end
21+
22+
ResetAfterNSteps(n::Int) = ResetAfterNSteps(0, n)
23+
24+
function (r::ResetAfterNSteps)(policy, env)
25+
stop = r.t >= r.n
26+
r.t += 1
27+
if stop
28+
r.t = 0
29+
return true
30+
else
31+
return false
32+
end
33+
end

src/ReinforcementLearningCore/src/core/run.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,18 @@ Base.run(ex::Experiment) = run(ex.policy, ex.env, ex.stop_condition, ex.hook)
6060
function Base.run(
6161
policy::AbstractPolicy,
6262
env::AbstractEnv,
63-
stop_condition=StopAfterEpisode(1),
64-
hook=EmptyHook(),
63+
stop_condition = StopAfterEpisode(1),
64+
hook = EmptyHook(),
65+
reset_condition = ResetAtTerminal()
6566
)
6667
policy, env = check(policy, env)
67-
_run(policy, env, stop_condition, hook)
68+
_run(policy, env, stop_condition, hook, reset_condition)
6869
end
6970

7071
"Inject some customized checkings here by overwriting this function"
7172
check(policy, env) = policy, env
7273

73-
function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook)
74+
function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook, reset_condition)
7475

7576
hook(PreExperimentStage(), policy, env)
7677
policy(PreExperimentStage(), env)
@@ -80,7 +81,7 @@ function _run(policy::AbstractPolicy, env::AbstractEnv, stop_condition, hook)
8081
policy(PreEpisodeStage(), env)
8182
hook(PreEpisodeStage(), policy, env)
8283

83-
while !is_terminated(env) # one episode
84+
while !reset_condition(policy, env) # one episode
8485
policy(PreActStage(), env)
8586
hook(PreActStage(), policy, env)
8687

0 commit comments

Comments
 (0)