Skip to content

Commit 1bb46a9

Browse files
authored
Merge pull request #14 from findmyway/add_multiplex_traces
Unify the definition of `AbstractTraces`
2 parents ffc8576 + d609559 commit 1bb46a9

17 files changed

+713
-305
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
coverage/
12
*.jl.*.cov
23
*.jl.cov
34
*.jl.mem
45
/Manifest.toml
6+
7+
.DS_Store

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
88
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
99
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
1112
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
1213

1314
[compat]

README.md

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,78 @@
66

77
## Design
88

9-
A typical example of `Trajectory`:
9+
The relationship of several concepts provided in this package:
1010

11-
![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)
11+
```
12+
┌───────────────────────────────────┐
13+
│ Trajectory │
14+
│ ┌───────────────────────────────┐ │
15+
│ │ AbstractTraces │ │
16+
│ │ ┌───────────────┐ │ │
17+
│ │ :trace_A => │ AbstractTrace │ │ │
18+
│ │ └───────────────┘ │ │
19+
│ │ │ │
20+
│ │ ┌───────────────┐ │ │
21+
│ │ :trace_B => │ AbstractTrace │ │ │
22+
│ │ └───────────────┘ │ │
23+
│ │ ... ... │ │
24+
│ └───────────────────────────────┘ │
25+
│ ┌───────────┐ │
26+
│ │ Sampler │ │
27+
│ └───────────┘ │
28+
│ ┌────────────┐ │
29+
│ │ Controller │ │
30+
│ └────────────┘ │
31+
└───────────────────────────────────┘
32+
```
33+
34+
## `Trajectory`
35+
36+
A `Trajectory` contains 3 parts:
1237

13-
Exported APIs are:
38+
- A `container` to store data. (Usually an `AbstractTraces`)
39+
- A `sampler` to determine how to sample a batch from `container`
40+
- A `controller` to decide when to sample a new batch from the `container`
41+
42+
Typical usage:
1443

1544
```julia
16-
push!(trajectory; [trace_name=value]...)
17-
append!(trajectory; [trace_name=value]...)
45+
julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioControler(1.0, 3));
46+
47+
julia> for i in 1:5
48+
push!(t, (a=i, b=iseven(i)))
49+
end
1850

19-
for sample in trajectory
20-
# consume samples from the trajectory
21-
end
51+
julia> for batch in t
52+
println(batch)
53+
end
54+
(a = [4, 5, 1], b = Bool[1, 0, 0])
55+
(a = [3, 2, 4], b = Bool[0, 1, 1])
56+
(a = [4, 1, 2], b = Bool[1, 0, 1])
2257
```
2358

24-
A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc.
59+
**Traces**
60+
61+
- `Traces`
62+
- `MultiplexTraces`
63+
- `CircularSARTTraces`
64+
- `Episode`
65+
- `Episodes`
66+
67+
**Samplers**
68+
69+
- `BatchSampler`
70+
- `MetaSampler`
71+
- `MultiBatchSampler`
72+
73+
**Controllers**
74+
75+
- `InsertSampleRatioController`
76+
- `InsertSampleController`
77+
- `AsyncInsertSampleRatioController`
78+
79+
80+
Please refer tests for common usage. (TODO: generate docs and add links to above data structures)
2581

2682
## Acknowledgement
2783

src/Trajectories.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
module Trajectories
22

3+
include("patch.jl")
4+
5+
include("traces.jl")
36
include("samplers.jl")
47
include("controllers.jl")
5-
include("traces.jl")
6-
include("episodes.jl")
78
include("trajectory.jl")
8-
include("rendering.jl")
99
include("common/common.jl")
1010

1111
end

src/common/CircularArraySARTTraces.jl

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
export CircularArraySARTTraces
22

33
const CircularArraySARTTraces = Traces{
4-
SART,
4+
SSAART,
55
<:Tuple{
6+
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7+
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
68
<:Trace{<:CircularArrayBuffer},
79
<:Trace{<:CircularArrayBuffer},
8-
<:Trace{<:CircularArrayBuffer},
9-
<:Trace{<:CircularArrayBuffer}
1010
}
1111
}
1212

13-
1413
function CircularArraySARTTraces(;
1514
capacity::Int,
1615
state=Int => (),
@@ -23,32 +22,10 @@ function CircularArraySARTTraces(;
2322
reward_eltype, reward_size = reward
2423
terminal_eltype, terminal_size = terminal
2524

25+
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
26+
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
2627
Traces(
27-
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
28-
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
2928
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
3029
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3130
)
3231
end
33-
34-
function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
35-
inds = rand(s.rng, 1:length(t), s.batch_size)
36-
inds′ = inds .+ 1
37-
(
38-
state=t[:state][inds],
39-
action=t[:action][inds],
40-
reward=t[:reward][inds],
41-
terminal=t[:terminal][inds],
42-
next_state=t[:state][inds′],
43-
next_action=t[:state][inds′]
44-
) |> s.transformer
45-
end
46-
47-
function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
48-
if length(t[:state]) == length(t[:terminal]) + 1
49-
pop!(t[:state])
50-
pop!(t[:action])
51-
end
52-
push!(t[:state], x[:state])
53-
push!(t[:action], x[:action])
54-
end
Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
export CircularArraySLARTTraces
22

33
const CircularArraySLARTTraces = Traces{
4-
SLART,
4+
SSLLAART,
55
<:Tuple{
6+
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7+
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
8+
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
69
<:Trace{<:CircularArrayBuffer},
710
<:Trace{<:CircularArrayBuffer},
8-
<:Trace{<:CircularArrayBuffer},
9-
<:Trace{<:CircularArrayBuffer},
10-
<:Trace{<:CircularArrayBuffer}
1111
}
1212
}
1313

14-
1514
function CircularArraySLARTTraces(;
1615
capacity::Int,
1716
state=Int => (),
@@ -26,37 +25,11 @@ function CircularArraySLARTTraces(;
2625
reward_eltype, reward_size = reward
2726
terminal_eltype, terminal_size = terminal
2827

28+
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
29+
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
30+
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
2931
Traces(
30-
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
31-
legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer
32-
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
3332
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
3433
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3534
)
36-
end
37-
38-
function sample(s::BatchSampler, t::CircularArraySLARTTraces)
39-
inds = rand(s.rng, 1:length(t), s.batch_size)
40-
inds′ = inds .+ 1
41-
(
42-
state=t[:state][inds],
43-
legal_actions_mask=t[:legal_actions_mask][inds],
44-
action=t[:action][inds],
45-
reward=t[:reward][inds],
46-
terminal=t[:terminal][inds],
47-
next_state=t[:state][inds′],
48-
next_legal_actions_mask=t[:legal_actions_mask][inds′],
49-
next_action=t[:state][inds′]
50-
) |> s.transformer
51-
end
52-
53-
function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
54-
if length(t[:state]) == length(t[:terminal]) + 1
55-
pop!(t[:state])
56-
pop!(t[:legal_actions_mask])
57-
pop!(t[:action])
58-
end
59-
push!(t[:state], x[:state])
60-
push!(t[:legal_actions_mask], x[:legal_actions_mask])
61-
push!(t[:action], x[:action])
62-
end
35+
end

src/common/common.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
using CircularArrayBuffers
22

3-
const SA = (:state, :action)
4-
const SLA = (:state, :legal_actions_mask, :action)
3+
const SS = (:state, :next_state)
4+
const LL = (:legal_actions_mask, :next_legal_actions_mask)
5+
const AA = (:action, :next_action)
56
const RT = (:reward, :terminal)
6-
const SART = (:state, :action, :reward, :terminal)
7-
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
8-
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
9-
const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action)
7+
const SSAART = (SS..., AA..., RT...)
8+
const SSLLAART = (SS..., LL..., AA..., RT...)
109

1110
include("sum_tree.jl")
1211
include("CircularArraySARTTraces.jl")

src/episodes.jl

Lines changed: 0 additions & 107 deletions
This file was deleted.

src/patch.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import MLUtils
2+
3+
MLUtils.batch(x::AbstractArray{<:Number}) = x

0 commit comments

Comments
 (0)