Skip to content

Commit 1a0b458

Browse files
committed
Merge branch 'main' into normalization
2 parents 564cb5b + 9095619 commit 1a0b458

20 files changed

+883
-355
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
coverage/
12
*.jl.*.cov
23
*.jl.cov
34
*.jl.mem
45
/Manifest.toml
6+
7+
.DS_Store

Project.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
1-
name = "Trajectories"
1+
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
authors = ["Jun Tian <[email protected]> and contributors"]
43
version = "0.1.0"
54

65
[deps]
76
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
8-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
97
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
108
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
119
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12-
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
10+
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
1311

1412
[compat]
1513
CircularArrayBuffers = "0.1"
16-
Term = "0.3"
14+
MacroTools = "0.5"
15+
StackViews = "0.1"
1716
julia = "1.6"
1817
OnlineStats = "1.0"
1918

README.md

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,82 @@
1-
# Trajectories
1+
# ReinforcementLearningTrajectories
22

3-
[![Build Status](https://github.com/JuliaReinforcementLearning/Trajectories.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaReinforcementLearning/Trajectories.jl/actions/workflows/CI.yml?query=branch%3Amain)
4-
[![Coverage](https://codecov.io/gh/JuliaReinforcementLearning/Trajectories.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaReinforcementLearning/Trajectories.jl)
3+
[![Build Status](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl/actions/workflows/CI.yml?query=branch%3Amain)
4+
[![Coverage](https://codecov.io/gh/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaReinforcementLearning/ReinforcementLearningTrajectories.jl)
55
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/T/Trajectories.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/report.html)
66

77
## Design
88

9-
A typical example of `Trajectory`:
9+
The relationship of several concepts provided in this package:
1010

11-
![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)
11+
```
12+
┌───────────────────────────────────┐
13+
│ Trajectory │
14+
│ ┌───────────────────────────────┐ │
15+
│ │ AbstractTraces │ │
16+
│ │ ┌───────────────┐ │ │
17+
│ │ :trace_A => │ AbstractTrace │ │ │
18+
│ │ └───────────────┘ │ │
19+
│ │ │ │
20+
│ │ ┌───────────────┐ │ │
21+
│ │ :trace_B => │ AbstractTrace │ │ │
22+
│ │ └───────────────┘ │ │
23+
│ │ ... ... │ │
24+
│ └───────────────────────────────┘ │
25+
│ ┌───────────┐ │
26+
│ │ Sampler │ │
27+
│ └───────────┘ │
28+
│ ┌────────────┐ │
29+
│ │ Controller │ │
30+
│ └────────────┘ │
31+
└───────────────────────────────────┘
32+
```
33+
34+
## `Trajectory`
35+
36+
A `Trajectory` contains 3 parts:
1237

13-
Exported APIs are:
38+
- A `container` to store data. (Usually an `AbstractTraces`)
39+
- A `sampler` to determine how to sample a batch from `container`
40+
- A `controller` to decide when to sample a new batch from the `container`
41+
42+
Typical usage:
1443

1544
```julia
16-
push!(trajectory; [trace_name=value]...)
17-
append!(trajectory; [trace_name=value]...)
45+
julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioControler(1.0, 3));
46+
47+
julia> for i in 1:5
48+
push!(t, (a=i, b=iseven(i)))
49+
end
1850

19-
for sample in trajectory
20-
# consume samples from the trajectory
21-
end
51+
julia> for batch in t
52+
println(batch)
53+
end
54+
(a = [4, 5, 1], b = Bool[1, 0, 0])
55+
(a = [3, 2, 4], b = Bool[0, 1, 1])
56+
(a = [4, 1, 2], b = Bool[1, 0, 1])
2257
```
2358

24-
A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc.
59+
**Traces**
60+
61+
- `Traces`
62+
- `MultiplexTraces`
63+
- `CircularSARTTraces`
64+
- `Episode`
65+
- `Episodes`
66+
67+
**Samplers**
68+
69+
- `BatchSampler`
70+
- `MetaSampler`
71+
- `MultiBatchSampler`
72+
73+
**Controllers**
74+
75+
- `InsertSampleRatioController`
76+
- `AsyncInsertSampleRatioController`
77+
78+
79+
Please refer tests for common usage. (TODO: generate docs and add links to above data structures)
2580

2681
## Acknowledgement
2782

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module ReinforcementLearningTrajectories
2+
3+
const RLTrajectories = ReinforcementLearningTrajectories
4+
export RLTrajectories
5+
6+
include("patch.jl")
7+
8+
include("traces.jl")
9+
include("samplers.jl")
10+
include("controllers.jl")
11+
include("trajectory.jl")
12+
include("common/common.jl")
13+
14+
end

src/Trajectories.jl

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/common/CircularArraySARTTraces.jl

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
export CircularArraySARTTraces
22

33
const CircularArraySARTTraces = Traces{
4-
SART,
4+
SSAART,
55
<:Tuple{
6+
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7+
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
68
<:Trace{<:CircularArrayBuffer},
79
<:Trace{<:CircularArrayBuffer},
8-
<:Trace{<:CircularArrayBuffer},
9-
<:Trace{<:CircularArrayBuffer}
1010
}
1111
}
1212

13-
1413
function CircularArraySARTTraces(;
1514
capacity::Int,
1615
state=Int => (),
@@ -23,32 +22,10 @@ function CircularArraySARTTraces(;
2322
reward_eltype, reward_size = reward
2423
terminal_eltype, terminal_size = terminal
2524

25+
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
26+
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
2627
Traces(
27-
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
28-
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
2928
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
3029
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3130
)
3231
end
33-
34-
function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
35-
inds = rand(s.rng, 1:length(t), s.batch_size)
36-
inds′ = inds .+ 1
37-
(
38-
state=t[:state][inds],
39-
action=t[:action][inds],
40-
reward=t[:reward][inds],
41-
terminal=t[:terminal][inds],
42-
next_state=t[:state][inds′],
43-
next_action=t[:state][inds′]
44-
) |> s.transformer
45-
end
46-
47-
function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
48-
if length(t[:state]) == length(t[:terminal]) + 1
49-
pop!(t[:state])
50-
pop!(t[:action])
51-
end
52-
push!(t[:state], x[:state])
53-
push!(t[:action], x[:action])
54-
end
Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
export CircularArraySLARTTraces
22

33
const CircularArraySLARTTraces = Traces{
4-
SLART,
4+
SSLLAART,
55
<:Tuple{
6+
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7+
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
8+
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
69
<:Trace{<:CircularArrayBuffer},
710
<:Trace{<:CircularArrayBuffer},
8-
<:Trace{<:CircularArrayBuffer},
9-
<:Trace{<:CircularArrayBuffer},
10-
<:Trace{<:CircularArrayBuffer}
1111
}
1212
}
1313

14-
1514
function CircularArraySLARTTraces(;
1615
capacity::Int,
1716
state=Int => (),
@@ -26,37 +25,11 @@ function CircularArraySLARTTraces(;
2625
reward_eltype, reward_size = reward
2726
terminal_eltype, terminal_size = terminal
2827

28+
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
29+
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
30+
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
2931
Traces(
30-
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
31-
legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer
32-
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
3332
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
3433
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3534
)
36-
end
37-
38-
function sample(s::BatchSampler, t::CircularArraySLARTTraces)
39-
inds = rand(s.rng, 1:length(t), s.batch_size)
40-
inds′ = inds .+ 1
41-
(
42-
state=t[:state][inds],
43-
legal_actions_mask=t[:legal_actions_mask][inds],
44-
action=t[:action][inds],
45-
reward=t[:reward][inds],
46-
terminal=t[:terminal][inds],
47-
next_state=t[:state][inds′],
48-
next_legal_actions_mask=t[:legal_actions_mask][inds′],
49-
next_action=t[:state][inds′]
50-
) |> s.transformer
51-
end
52-
53-
function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
54-
if length(t[:state]) == length(t[:terminal]) + 1
55-
pop!(t[:state])
56-
pop!(t[:legal_actions_mask])
57-
pop!(t[:action])
58-
end
59-
push!(t[:state], x[:state])
60-
push!(t[:legal_actions_mask], x[:legal_actions_mask])
61-
push!(t[:action], x[:action])
62-
end
35+
end

src/common/common.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
using CircularArrayBuffers
22

3-
const SA = (:state, :action)
4-
const SLA = (:state, :legal_actions_mask, :action)
3+
const SS = (:state, :next_state)
4+
const LL = (:legal_actions_mask, :next_legal_actions_mask)
5+
const AA = (:action, :next_action)
56
const RT = (:reward, :terminal)
6-
const SART = (:state, :action, :reward, :terminal)
7-
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
8-
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
9-
const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action)
7+
const SSAART = (SS..., AA..., RT...)
8+
const SSLLAART = (SS..., LL..., AA..., RT...)
109

1110
include("sum_tree.jl")
1211
include("CircularArraySARTTraces.jl")

src/controlers.jl renamed to src/controllers.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,26 @@
1-
export InsertSampleRatioControler, AsyncInsertSampleRatioControler
2-
3-
mutable struct InsertSampleRatioControler
4-
ratio::Float64
5-
threshold::Int
6-
n_inserted::Int
7-
n_sampled::Int
8-
end
1+
export InsertSampleRatioController, AsyncInsertSampleRatioController
92

103
"""
11-
InsertSampleRatioControler(ratio, threshold)
4+
InsertSampleRatioController(;ratio=1., threshold=1)
125
136
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
147
insertings before sampling. The `ratio` balances the number of insertings and
158
the number of samplings.
169
"""
17-
InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0)
10+
Base.@kwdef mutable struct InsertSampleRatioController
11+
ratio::Float64 = 1.0
12+
threshold::Int = 1
13+
n_inserted::Int = 0
14+
n_sampled::Int = 0
15+
end
1816

19-
function on_insert!(c::InsertSampleRatioControler, n::Int)
17+
function on_insert!(c::InsertSampleRatioController, n::Int)
2018
if n > 0
2119
c.n_inserted += n
2220
end
2321
end
2422

25-
function on_sample!(c::InsertSampleRatioControler)
23+
function on_sample!(c::InsertSampleRatioController)
2624
if c.n_inserted >= c.threshold
2725
if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
2826
c.n_sampled += 1
@@ -33,7 +31,7 @@ end
3331

3432
#####
3533

36-
mutable struct AsyncInsertSampleRatioControler
34+
mutable struct AsyncInsertSampleRatioController
3735
ratio::Float64
3836
threshold::Int
3937
n_inserted::Int
@@ -42,15 +40,15 @@ mutable struct AsyncInsertSampleRatioControler
4240
ch_out::Channel
4341
end
4442

45-
function AsyncInsertSampleRatioControler(
43+
function AsyncInsertSampleRatioController(
4644
ratio,
4745
threshold,
4846
; ch_in_sz=1,
4947
ch_out_sz=1,
5048
n_inserted=0,
5149
n_sampled=0
5250
)
53-
AsyncInsertSampleRatioControler(
51+
AsyncInsertSampleRatioController(
5452
ratio,
5553
threshold,
5654
n_inserted,

0 commit comments

Comments
 (0)