
Commit 67bcfc4

Merge branch 'main' into mpo-imp
2 parents 9ebbbfa + b54a0b0 commit 67bcfc4

File tree

24 files changed (+295, -65 lines)

.cspell/julia_words.txt

Lines changed: 2 additions & 1 deletion
@@ -5294,4 +5294,5 @@ sqmahal
 logdpf
 devmode
 logpdfs
-kldivs
+kldivs
+Riedmiller

docs/src/How_to_implement_a_new_algorithm.md

Lines changed: 17 additions & 14 deletions
@@ -46,7 +46,7 @@ end
 
 ```
 
-Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` subtype, its action sampling method (by overloading `Base.push!(policy::YourPolicyType, env)`) and implementing its behavior at each stage. However, ReinforcemementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) have less code to write 2) lower the chances of bugs and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
+Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` (or `AbstractLearner`, see [this section](#using-resources-from-rlcore)) subtype, its action sampling method (by overloading `RLBase.plan!(policy::YourPolicyType, env)`) and implementing its behavior at each stage. However, ReinforcementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) have less code to write, 2) lower the chances of bugs, and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm).
 
 ## Using Agents
 The recommended way is to use the policy wrapper `Agent`. An agent is itself an `AbstractPolicy` that wraps a policy and a trajectory (also called Experience Replay Buffer in RL literature). Agent comes with default implementations of `push!(agent, stage, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are
@@ -73,7 +73,7 @@ If you need a different behavior at some stages, then you can overload the `Base
 
 ## Updating the policy
 
-Finally, you need to implement the learning function by implementing `RLBase.optimise!(::YourPolicyType, ::Stage, ::Trajectory)`. By default this does nothing at all stages. Overload it on the stage where you wish to optimise (most often, at `PreActStage` or `PostEpisodeStage`). This function should loop the trajectory to sample batches. Inside the loop, put whatever is required. For example:
+Finally, you need to implement the learning function by implementing `RLBase.optimise!(::YourPolicyType, ::Stage, ::Trajectory)`. By default this does nothing at all stages. Overload it on the stage where you wish to optimise (most often, at `PreActStage()`, `PostActStage()` or `PostEpisodeStage()`). This function should loop the trajectory to sample batches. Inside the loop, put whatever is required. For example:
 
 ```julia
 function RLBase.optimise!(p::YourPolicyType, ::PostEpisodeStage, traj::Trajectory)
@@ -83,7 +83,7 @@ function RLBase.optimise!(p::YourPolicyType, ::PostEpisodeStage, traj::Trajector
 end
 
 ```
-where `optimise!(p, batch)` is a function that will typically compute the gradient and update a neural network, or update tabular policy. What is inside the loop is free to be whatever you need. This is further discussed in the next section on `Trajectory`s.
+where `optimise!(p, batch)` is a function that will typically compute the gradient and update a neural network, or update a tabular policy. What is inside the loop is free to be whatever you need, but it is a good idea to implement an `optimise!(p::YourPolicyType, batch::NamedTuple)` function for clarity instead of coding everything in the loop. This is further discussed in the next section on `Trajectory`s. One case where you may need a different structure is when updating priorities; see [the PER learner](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl) for an example.
 
 ## ReinforcementLearningTrajectories
 
@@ -122,29 +122,32 @@ The sampler is the object that will fetch data in your trajectory to create the
 
 ## Using resources from RLCore
 
-RL algorithms typically only differ partially but broadly use the same mechanisms. The subpackage RLCore contains a lot of utilities that you can reuse to implement your algorithm.
+RL algorithms typically only differ partially but broadly use the same mechanisms. The subpackage RLCore contains some utilities that you can reuse to implement your algorithm.
 
-The utils folder contains utilities and extensions to external packages to fit needs that are specific to RL.jl. We will not list them all here, but it is a good idea to skim over the files to see what they contain. The policies folder notably contains several explorer implementations. Here are a few interesting examples:
+### QBasedPolicy
 
-- `QBasedPolicy` wraps a policy that relies on a Q-Value _learner_ (tabular or approximated) and an _explorer_ .
-RLCore provides several pre-implemented learners and the most common explorers (such as epsilon-greedy, UCB, etc.).
+`QBasedPolicy` is a policy that wraps a Q-value _learner_ (tabular or approximated) and an _explorer_. Use this wrapper to implement a policy that directly uses a Q-value function to
+decide its next action. In that case, instead of creating an `AbstractPolicy` subtype for your algorithm, define an `AbstractLearner` subtype and specialize `RLBase.optimise!(::YourLearnerType, ::Stage, ::Trajectory)`. This way you will not have to code the interaction between your policy and the explorer yourself.
+RLCore provides the most common explorers (such as epsilon-greedy, UCB, etc.).
 
-- If your algorithm use tabular learners, check out the tabular_learner.jl and the tabular_approximator source files. If your algorithms uses deep neural nets then use the `NeuralNetworkApproximator` to wrap an Neural Network and an optimizer. Common policy architectures are also provided such as the `GaussianNetwork`.
+### Neural and linear approximators
 
-- Equivalently, the `VBasedPolicy` learner is provided for algorithms that use a state-value function. Though they are not bundled in the same folder, most approximators can be used with a VBasedPolicy too.
+If your algorithm uses a neural network or a linear approximator to approximate a function trained with `Flux.jl`, use the `Approximator`. An `Approximator`
+wraps a `Flux` model and an `Optimiser` (such as Adam or SGD). Your `optimise!(::PolicyOrLearner, batch)` function will probably consist of computing a gradient
+and then calling `RLCore.optimise!(app::Approximator, gradient::Flux.Grads)`.
 
-<!--- ### Batch samplers
-Since this is going to be outdated soon, I'll write this part later on when Trajectories.jl will be done -->
+Common model architectures are also provided, such as the `GaussianNetwork` for continuous policies with diagonal multivariate Gaussian distributions, and the `CovGaussianNetwork` for full covariance (very slow on GPUs at the moment).
 
-- In utils/distributions.jl you will find implementations of gaussian log probabilities functions that are both GPU compatible and differentiable and that do not require the overhead of using Distributions.jl structs.
+### Utils
+In utils/distributions.jl you will find implementations of Gaussian log-probability functions that are both GPU compatible and differentiable and that do not require the overhead of using Distributions.jl structs.
 
 ## Conventions
 Finally, there are a few "conventions" and good practices that you should follow, especially if you intend to contribute to this package (don't worry we'll be happy to help if needed).
 
 ### Random Numbers
-ReinforcementLearning.jl aims to provide a framework for reproducible experiments. To do so, make sure that your policy type has a `rng` field and that all random operations (e.g. action sampling or trajectory sampling) use `rand(your_policy.rng, args...)`.
+ReinforcementLearning.jl aims to provide a framework for reproducible experiments. To do so, make sure that your policy type has a `rng` field and that all random operations (e.g. action sampling) use `rand(your_policy.rng, args...)`. For trajectory sampling, you can set the sampler's rng to that of the policy when creating an agent, or simply let it instantiate its own rng.
 
-### GPU friendlyness
+### GPU compatibility
 Deep RL algorithms are often much faster when the neural nets are updated on a GPU. For now, we only support CUDA.jl as a backend. This means that you will have to think about the transfer of data between the CPU (where the trajectory is) and the GPU memory (where the neural nets are). To do so you will find in utils/device.jl some functions that do most of the work for you. The ones that you need to know are `send_to_device(device, data)` that sends data to the specified device, `send_to_host(data)` which sends data to the CPU memory (it fallbacks to `send_to_device(Val{:cpu}, data)`) and `device(x)` that returns the device on which `x` is.
 Normally, you should be able to write a single implementation of your algorithm that works on CPU and GPUs thanks to the multiple dispatch offered by Julia.
 

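The QBasedPolicy and Approximator guidance added above can be pulled together in a small sketch. This is not part of the commit: `MyQLearner`, the forwarding method and all hyperparameters are illustrative assumptions; `QBasedPolicy`, `Approximator`, `AbstractLearner` and `EpsilonGreedyExplorer` are the RLCore names referenced in the diff.

```julia
using ReinforcementLearning
using Flux

# Hypothetical learner: it only needs to produce Q-values from a state.
Base.@kwdef struct MyQLearner <: AbstractLearner
    approximator::Approximator
end

# Assumed forwarding: delegate state -> Q-values to the wrapped Approximator's model.
RLCore.forward(learner::MyQLearner, state::AbstractArray) =
    RLCore.forward(learner.approximator, state)

policy = QBasedPolicy(
    learner = MyQLearner(
        approximator = Approximator(
            model = Chain(Dense(4 => 32, relu), Dense(32 => 2)),  # e.g. CartPole-sized in/out
            optimiser = Adam(),
        ),
    ),
    explorer = EpsilonGreedyExplorer(0.1),
)
```

Together with the `q_based_policy.jl` change further down in this commit, `optimise!(policy, stage, trajectory)` forwards to the learner, so the training logic lives entirely in `MyQLearner`'s `optimise!` methods.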
docs/src/tips.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,8 @@ dependency, remember to update both `docs/Project.toml` and
2727
All the cells after the `#+ tangle=true` line in `Your_Experment.jl` will be extracted into the
2828
`ReinforcementLearningExperiments` package automatically. This feature is
2929
supported by [Weave.jl](https://weavejl.mpastell.com/stable/usage/#tangle).
30+
31+
## How to enable debug timings for experiment runs?
32+
33+
Call `RLCore.TimerOutputs.enable_debug_timings(RLCore)` and default timings for hooks, policies and optimization steps will be printed. How do I reset the timer? Call `RLCore.TimerOutputs.reset_timer!(RLCore.timer)`. How do I show the timer results? Call `RLCore.timer`.
34+

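A usage sketch of this tip (not part of the commit). The policy, environment and stop condition below are illustrative assumptions from the wider ReinforcementLearning.jl ecosystem; only the `RLCore.TimerOutputs` calls come from the tip itself:

```julia
using ReinforcementLearning

RLCore.TimerOutputs.enable_debug_timings(RLCore)  # switch on the debug timers inside RLCore
RLCore.TimerOutputs.reset_timer!(RLCore.timer)    # optional: start from a clean timer

# Any experiment run works; this one is just a placeholder.
run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10))

RLCore.timer  # display the collected timings (reset!, plan!, act!, optimise!, hooks, ...)
```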
src/ReinforcementLearningCore/Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,6 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
-version = "0.11.0"
+version = "0.11.2"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -22,6 +22,7 @@ ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningTrajectories = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 
 [compat]
@@ -41,6 +42,7 @@ Reexport = "1"
 ReinforcementLearningBase = "0.12"
 ReinforcementLearningTrajectories = "^0.1.9"
 StatsBase = "0.32, 0.33, 0.34"
+TimerOutputs = "0.5"
 UnicodePlots = "1.3, 2, 3"
 julia = "1.9"
 

src/ReinforcementLearningCore/src/ReinforcementLearningCore.jl

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 module ReinforcementLearningCore
 
+using TimerOutputs
 using ReinforcementLearningBase
 using Reexport
 
@@ -14,4 +15,7 @@ include("core/core.jl")
 include("policies/policies.jl")
 include("utils/utils.jl")
 
+# Global timer for TimerOutputs.jl
+const timer = TimerOutput()
+
 end # module

src/ReinforcementLearningCore/src/core/run.jl

Lines changed: 20 additions & 19 deletions
@@ -87,37 +87,38 @@ function _run(policy::AbstractPolicy,
     push!(policy, PreExperimentStage(), env)
     is_stop = false
     while !is_stop
-        reset!(env)
-        push!(policy, PreEpisodeStage(), env)
-        optimise!(policy, PreEpisodeStage())
-        push!(hook, PreEpisodeStage(), policy, env)
+        # NOTE: @timeit_debug statements are used for debug logging
+        @timeit_debug timer "reset!" reset!(env)
+        @timeit_debug timer "push!(policy) PreEpisodeStage" push!(policy, PreEpisodeStage(), env)
+        @timeit_debug timer "optimise! PreEpisodeStage" optimise!(policy, PreEpisodeStage())
+        @timeit_debug timer "push!(hook) PreEpisodeStage" push!(hook, PreEpisodeStage(), policy, env)


         while !reset_condition(policy, env) # one episode
-            push!(policy, PreActStage(), env)
-            optimise!(policy, PreActStage())
-            push!(hook, PreActStage(), policy, env)
+            @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
+            @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
+            @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)

-            action = RLBase.plan!(policy, env)
-            act!(env, action)
+            action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
+            @timeit_debug timer "act!" act!(env, action)

-            push!(policy, PostActStage(), env)
-            optimise!(policy, PostActStage())
-            push!(hook, PostActStage(), policy, env)
+            @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
+            @timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
+            @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

             if check_stop(stop_condition, policy, env)
                 is_stop = true
-                push!(policy, PreActStage(), env)
-                optimise!(policy, PreActStage())
-                push!(hook, PreActStage(), policy, env)
-                RLBase.plan!(policy, env) # let the policy see the last observation
+                @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
+                @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
+                @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)
+                @timeit_debug timer "plan!" RLBase.plan!(policy, env) # let the policy see the last observation
                 break
             end
         end # end of an episode

-        push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
-        optimise!(policy, PostEpisodeStage())
-        push!(hook, PostEpisodeStage(), policy, env)
+        @timeit_debug timer "push!(policy) PostEpisodeStage" push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
+        @timeit_debug timer "optimise! PostEpisodeStage" optimise!(policy, PostEpisodeStage())
+        @timeit_debug timer "push!(hook) PostEpisodeStage" push!(hook, PostEpisodeStage(), policy, env)

     end
     push!(policy, PostExperimentStage(), env)

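For context on the `@timeit_debug` pattern introduced above: it records into the module's global `timer` only once debug timings are enabled for that module (see the tips.md change earlier in this commit), and otherwise compiles to a no-op. A standalone sketch of the mechanism, with a hypothetical module and function name:

```julia
using TimerOutputs

module TimingDemo
    using TimerOutputs
    const timer = TimerOutput()          # mirrors RLCore's global `timer`
    step!() = @timeit_debug timer "step!" sum(rand(1_000))
end

TimingDemo.step!()                             # does nothing: debug timings are off
TimerOutputs.enable_debug_timings(TimingDemo)  # turn the debug timers on for this module
TimingDemo.step!()                             # now "step!" is recorded in the timer
show(TimingDemo.timer)                         # print the collected timings
```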
src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl

Lines changed: 20 additions & 19 deletions
@@ -108,34 +108,35 @@ function Base.run(
     push!(multiagent_policy, PreExperimentStage(), env)
     is_stop = false
     while !is_stop
-        reset!(env)
-        push!(multiagent_policy, PreEpisodeStage(), env)
-        optimise!(multiagent_policy, PreEpisodeStage())
-        push!(multiagent_hook, PreEpisodeStage(), multiagent_policy, env)
+        # NOTE: @timeit_debug statements are for debug logging
+        @timeit_debug timer "reset!" reset!(env)
+        @timeit_debug timer "push!(policy) PreEpisodeStage" push!(multiagent_policy, PreEpisodeStage(), env)
+        @timeit_debug timer "optimise! PreEpisodeStage" optimise!(multiagent_policy, PreEpisodeStage())
+        @timeit_debug timer "push!(hook) PreEpisodeStage" push!(multiagent_hook, PreEpisodeStage(), multiagent_policy, env)

         while !(reset_condition(multiagent_policy, env) || is_stop) # one episode
             for player in CurrentPlayerIterator(env)
                 policy = multiagent_policy[player] # Select appropriate policy
                 hook = multiagent_hook[player] # Select appropriate hook
-                push!(policy, PreActStage(), env)
-                optimise!(policy, PreActStage())
-                push!(hook, PreActStage(), policy, env)
+                @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
+                @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
+                @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)

-                action = RLBase.plan!(policy, env)
-                act!(env, action)
+                action = @timeit_debug timer "plan!" RLBase.plan!(policy, env)
+                @timeit_debug timer "act!" act!(env, action)



-                push!(policy, PostActStage(), env)
-                optimise!(policy, PostActStage())
-                push!(hook, PostActStage(), policy, env)
+                @timeit_debug timer "push!(policy) PostActStage" push!(policy, PostActStage(), env)
+                @timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
+                @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

                 if check_stop(stop_condition, policy, env)
                     is_stop = true
-                    push!(multiagent_policy, PreActStage(), env)
-                    optimise!(multiagent_policy, PreActStage())
-                    push!(multiagent_hook, PreActStage(), policy, env)
-                    RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
+                    @timeit_debug timer "push!(policy) PreActStage" push!(multiagent_policy, PreActStage(), env)
+                    @timeit_debug timer "optimise! PreActStage" optimise!(multiagent_policy, PreActStage())
+                    @timeit_debug timer "push!(hook) PreActStage" push!(multiagent_hook, PreActStage(), policy, env)
+                    @timeit_debug timer "plan!" RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
                     break
                 end

@@ -145,9 +146,9 @@ function Base.run(
             end
         end # end of an episode

-        push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
-        optimise!(multiagent_policy, PostEpisodeStage())
-        push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
+        @timeit_debug timer "push!(policy) PostEpisodeStage" push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
+        @timeit_debug timer "optimise! PostEpisodeStage" optimise!(multiagent_policy, PostEpisodeStage())
+        @timeit_debug timer "push!(hook) PostEpisodeStage" push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
     end
     push!(multiagent_policy, PostExperimentStage(), env)
     push!(multiagent_hook, PostExperimentStage(), multiagent_policy, env)

src/ReinforcementLearningCore/src/policies/learners.jl

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ Base.show(io::IO, m::MIME"text/plain", L::AbstractLearner) = show(io, m, convert
 # Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
 forward(L::Le, env::E) where {Le <: AbstractLearner, E <: AbstractEnv} = env |> state |> send_to_device(L.approximator) |> x -> forward(L, x) |> send_to_device(env)
 
+function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end
+
 Base.@kwdef mutable struct Approximator{M,O}
     model::M
     optimiser::O

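The new no-op fallback means a custom learner only needs to specialize `optimise!` for the stage it actually trains on; every other stage silently does nothing. A minimal sketch (the learner type and its field are hypothetical):

```julia
using ReinforcementLearning

# Hypothetical learner relying on the fallback above for all other stages.
struct MyTDLearner <: AbstractLearner
    approximator::Approximator
end

# Only PostActStage triggers training; every other stage hits the no-op fallback.
function RLBase.optimise!(learner::MyTDLearner, ::PostActStage, trajectory::Trajectory)
    for batch in trajectory
        # compute a TD target from `batch` and update learner.approximator here
    end
end
```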
src/ReinforcementLearningCore/src/policies/q_based_policy.jl

Lines changed: 7 additions & 5 deletions
@@ -7,6 +7,11 @@ using Functors: @functor
 
 """
     QBasedPolicy(;learner, explorer)
+
+Wraps a learner and an explorer. The learner is a struct that should predict the Q-value of each legal
+action of an environment at its current state. It is typically a table or a neural network.
+QBasedPolicy can be queried for an action with `RLBase.plan!`, the explorer will affect the action selection
+accordingly.
 """
 Base.@kwdef mutable struct QBasedPolicy{L,E} <: AbstractPolicy
     "estimate the Q value"
@@ -37,8 +42,5 @@ end
 RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
     prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env))
 
-function RLBase.optimise!(p::QBasedPolicy{L,Ex}, ::PostActStage, trajectory::Trajectory) where {L<:AbstractLearner,Ex<:AbstractExplorer}
-    for batch in trajectory
-        RLBase.optimise!(p.learner, batch)
-    end
-end
+#the internal learner defines the optimization stage.
+RLBase.optimise!(p::QBasedPolicy, s::AbstractStage, trajectory::Trajectory) = RLBase.optimise!(p.learner, s, trajectory)
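
A short note on what the new delegation gives you, reusing the hypothetical `policy` built from `MyQLearner` in the sketch earlier on this page and some `trajectory` of yours:

```julia
# Delegation added in this hunk: the wrapped learner decides when training happens.
RLBase.optimise!(policy, PostActStage(), trajectory)
# is now equivalent to
RLBase.optimise!(policy.learner, PostActStage(), trajectory)
# and does nothing unless the learner specializes optimise! for PostActStage.
```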
