Skip to content

adding an EpisodesBuffer #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
909c364
recreate episodes
HenriDeh Jun 19, 2023
b926a33
merge
HenriDeh Jun 26, 2023
50587be
implement append traces with multiplex
HenriDeh Jun 28, 2023
85f6e3d
Merge branch 'append' into EpisodeContainer
HenriDeh Jun 28, 2023
9804d45
add Episodesbuffer
HenriDeh Jun 30, 2023
8b4bb20
include
HenriDeh Jun 30, 2023
71bf1c5
add RLTraj.capacity and remove old Episode stuff
HenriDeh Jun 30, 2023
e05243f
new capacity definition
HenriDeh Jun 30, 2023
57163f3
add tests
HenriDeh Jun 30, 2023
c1d6b8d
dummy samples
HenriDeh Jun 30, 2023
3f52521
remove pointer artifact
HenriDeh Jun 30, 2023
e8a740f
remove old episode tests
HenriDeh Jun 30, 2023
fb81276
fix tests
HenriDeh Jun 30, 2023
d73316e
Rework episodes
HenriDeh Jul 3, 2023
258c5e0
typing
HenriDeh Jul 3, 2023
748e876
StatsBase and sampling
HenriDeh Jul 3, 2023
c6007b4
StatsBase
HenriDeh Jul 3, 2023
2586d55
no export Episodes
HenriDeh Jul 3, 2023
0a2424d
test sampling
HenriDeh Jul 3, 2023
9f7ec72
Priorities compat
HenriDeh Jul 4, 2023
2fd8db2
fix tests
HenriDeh Jul 4, 2023
178ef24
Autowrap and normalization compat
HenriDeh Jul 6, 2023
d180a8d
bump version
HenriDeh Jul 6, 2023
8ccbf95
update readme
HenriDeh Jul 6, 2023
2571796
update readme
HenriDeh Jul 6, 2023
767f764
Merge branch 'EpisodeContainer' of https://github.com/JuliaReinforcem…
HenriDeh Jul 6, 2023
63eb4c9
increase codecov
HenriDeh Jul 6, 2023
3a58315
rename to SARSA trace
HenriDeh Jul 6, 2023
6fffbd1
partial insertion forcing
HenriDeh Jul 6, 2023
3495151
add trajectory key fetcher
HenriDeh Jul 6, 2023
01e8117
add a comment
HenriDeh Jul 6, 2023
6dbe801
fix sarsa capacity
HenriDeh Jul 6, 2023
2b2eb4a
move partial
HenriDeh Jul 6, 2023
824765f
deal with partialnt indexation
HenriDeh Jul 6, 2023
3a05be0
add tests for partial namedtuple
HenriDeh Jul 6, 2023
ca511f7
Test without partialnt
HenriDeh Jul 6, 2023
65978c9
sample namespace
HenriDeh Jul 6, 2023
8484578
drop commented code
HenriDeh Jul 7, 2023
b81d18f
rename to SARST
HenriDeh Jul 7, 2023
8205e9e
rename traces
HenriDeh Jul 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.1.10"
version = "0.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Adapt = "3"
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ The relationship of several concepts provided in this package:
┌───────────────────────────────────┐
│ Trajectory │
│ ┌───────────────────────────────┐ │
│ │ AbstractTraces │ │
│ │ EpisodesBuffer wrapping a │ │
│ │ AbstractTraces │ │
│ │ ┌───────────────┐ │ │
│ │ :trace_A => │ AbstractTrace │ │ │
│ │ └───────────────┘ │ │
Expand Down Expand Up @@ -61,8 +62,7 @@ julia> for batch in t
- `Traces`
- `MultiplexTraces`
- `CircularSARTTraces`
- `Episode`
- `Episodes`
- `NormalizedTraces`

**Samplers**

Expand Down
3 changes: 2 additions & 1 deletion src/ReinforcementLearningTrajectories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ module ReinforcementLearningTrajectories

const RLTrajectories = ReinforcementLearningTrajectories
export RLTrajectories
import StatsBase

include("patch.jl")
include("traces.jl")
include("common/common.jl")
include("episodes.jl")
include("samplers.jl")
include("controllers.jl")
include("trajectory.jl")
include("normalization.jl")

end
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
export CircularArraySARTTraces
export CircularArraySARTSATraces

import CircularArrayBuffers
import CircularArrayBuffers.CircularArrayBuffer

const CircularArraySARTTraces = Traces{
const CircularArraySARTSATraces = Traces{
SS′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:CircularArrayBuffer}},
Expand All @@ -12,7 +12,7 @@ const CircularArraySARTTraces = Traces{
}
}

function CircularArraySARTTraces(;
function CircularArraySARTSATraces(;
capacity::Int,
state=Int => (),
action=Int => (),
Expand All @@ -24,12 +24,12 @@ function CircularArraySARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTTraces) = CircularArrayBuffers.capacity(t.traces[end])
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
35 changes: 35 additions & 0 deletions src/common/CircularArraySARTSTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
export CircularArraySARTSTraces

import CircularArrayBuffers.CircularArrayBuffer

# Type alias for the SARTS layout: `state`/`next_state` are multiplexed (stored
# once in a single buffer, read at two offsets), while `action`, `reward` and
# `terminal` each get their own circular buffer.
const CircularArraySARTSTraces = Traces{
    SS′ART,
    <:Tuple{
        <:MultiplexTraces{SS′,<:Trace{<:CircularArrayBuffer}},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
    }
}

"""
    CircularArraySARTSTraces(; capacity, state, action, reward, terminal)

Construct a SARTS trace collection backed by circular array buffers holding at
most `capacity` steps. Each keyword argument is an `eltype => size` pair
describing the element type and per-step dimensions of that trace.
"""
function CircularArraySARTSTraces(;
    capacity::Int,
    state=Int => (),
    action=Int => (),
    reward=Float32 => (),
    terminal=Bool => ())

    s_T, s_dims = state
    a_T, a_dims = action
    r_T, r_dims = reward
    t_T, t_dims = terminal

    # The state buffer holds `capacity + 1` entries because it is multiplexed
    # (the same storage backs both `state` and `next_state`).
    multiplexed_states = MultiplexTraces{SS′}(CircularArrayBuffer{s_T}(s_dims..., capacity + 1))
    per_step_traces = Traces(
        action = CircularArrayBuffer{a_T}(a_dims..., capacity),
        reward = CircularArrayBuffer{r_T}(r_dims..., capacity),
        terminal = CircularArrayBuffer{t_T}(t_dims..., capacity),
    )
    return multiplexed_states + per_step_traces
end

CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
2 changes: 2 additions & 0 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ function CircularArraySLARTTraces(;
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
2 changes: 1 addition & 1 deletion src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export CircularPrioritizedTraces

using CircularArrayBuffers: capacity, CircularVectorBuffer
using CircularArrayBuffers: CircularVectorBuffer

struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
keys::CircularVectorBuffer{Int,Vector{Int}}
Expand Down
5 changes: 2 additions & 3 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
export SS′, LL′, AA′, RT, SS′ART, SS′AA′RT, SS′L′ART, SS′LL′AA′RT

using CircularArrayBuffers

const SS′ = (:state, :next_state)
const LL′ = (:legal_actions_mask, :next_legal_actions_mask)
const AA′ = (:action, :next_action)
Expand All @@ -12,7 +10,8 @@ const SS′L′ART = (SS′..., :next_legal_actions_mask, :action, RT...)
const SS′LL′AA′RT = (SS′..., LL′..., AA′..., RT...)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
include("CircularArraySARTSTraces.jl")
include("CircularArraySARTSATraces.jl")
include("CircularArraySLARTTraces.jl")
include("CircularPrioritizedTraces.jl")
include("ElasticArraySARTTraces.jl")
138 changes: 138 additions & 0 deletions src/episodes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
export EpisodesBuffer, PartialNamedTuple
import DataStructures.CircularBuffer

"""
EpisodesBuffer(traces::AbstractTraces)

Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
`EpisodesBuffer` tracks the indexes of the `traces` object that belong to the same episodes.
To that end, it stores
1. an vector `sampleable_inds` of Booleans that determine whether an index in Traces is legally sampleable
(i.e., it is not the index of a last state of an episode);
2. a vector `episodes_lengths` that contains the total duration of the episode that each step belong to;
3. an vector `step_numbers` that contains the index within the episode of the corresponding step.

This information is used to correctly sample the traces. For example, if we have an episode that lasted 10 steps, the buffer stores 11 states. sampleable_inds[i] will
be true for the index of the first ten steps, and 0 at the index of state number 11. episodes_lengths will be 11
consecutive 10s (the episode saw 11 states but 10 steps occured). step_numbers will be 1 to 11.

If `traces` is a capacitated buffer, such as a CircularArraySARTSTraces, then these three vectors will also be circular.

EpisodesBuffer assumes that individual transitions are `push!`ed. Appending is not yet supported.
"""

mutable struct EpisodesBuffer{names, E, T<:AbstractTraces{names, E},B,S} <: AbstractTraces{names,E}
traces::T
sampleable_inds::S
step_numbers::B
episodes_lengths::B
end

"""
PartialNamedTuple(::NamedTuple)

Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should
ignore the fact that this is a partial insertion. Used at the end of an episode to
complete multiplex traces before moving to the next episode.
"""
struct PartialNamedTuple{T}
namedtuple::T
end

"""
    EpisodesBuffer(traces::AbstractTraces)

Wrap empty `traces` with fresh episode bookkeeping. When `traces` has finite
capacity, the bookkeeping vectors are `CircularBuffer`s of the same capacity so
that they roll over together with the data; otherwise they are plain growable
vectors.

Throws an `ArgumentError` if `traces` is not empty.
"""
function EpisodesBuffer(traces::AbstractTraces)
    # Validate input explicitly rather than with @assert (which may be disabled).
    isempty(traces) || throw(ArgumentError("EpisodesBuffer must be initialized with empty traces."))
    # MultiplexTraces hold one extra entry, hence the + 1 on the capacity.
    cap = any(t -> t isa MultiplexTraces, traces.traces) ? capacity(traces) + 1 : capacity(traces)
    if isinf(cap)
        # Unbounded traces: plain growable bookkeeping vectors.
        return EpisodesBuffer(traces, BitVector(), Vector{Int}(), Vector{Int}())
    else
        # Capacitated traces: circular bookkeeping so old entries roll off in sync.
        return EpisodesBuffer(traces, CircularBuffer{Bool}(cap), CircularBuffer{Int}(cap), CircularBuffer{Int}(cap))
    end
end

# Forward the standard container interface straight to the wrapped traces.
Base.getindex(eb::EpisodesBuffer, idx...) = getindex(eb.traces, idx...)
Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
Base.size(eb::EpisodesBuffer) = size(eb.traces)
Base.length(eb::EpisodesBuffer) = length(eb.traces)
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
# Prioritized traces nest the actual trace collection one level deeper.
Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces)
# Multi-line display: print the wrapper's name, then delegate to the wrapped
# traces' own text/plain show method.
function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names}
    s = nameof(typeof(eb))
    t = eb.traces
    println(io, "$s containing")
    show(io, m, t) # fixed: was `m::MIME"text/plain"`, a redundant typeassert (typo)
end

# A push is "partial" when it supplies fewer entries than the container has
# traces — e.g. only the initial state at the start of an episode.
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
ispartial_insert(ts::CircularPrioritizedTraces, xs) = ispartial_insert(ts.traces, xs)
ispartial_insert(ts::Traces, xs) = length(xs) < length(ts.traces) # counts traces, not steps

# Push a dummy duplicate of the last element onto every non-multiplexed trace so
# that all traces stay aligned after a partial (multiplex-only) insertion. The
# dummy entries are marked unsampleable by the caller and should never be sampled.
function _fill_multiplex!(traces)
    for trace in traces
        if !(trace isa MultiplexTraces)
            push!(trace, last(trace)) # dummy element, should never be sampled
        end
    end
    return nothing
end

fill_multiplex(es::EpisodesBuffer) = _fill_multiplex!(es.traces.traces)
# Prioritized traces nest the actual trace collection one level deeper.
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = _fill_multiplex!(es.traces.traces.traces)

# Insert one transition into the buffer and update the episode bookkeeping.
# Three cases: the very first (partial) push of an episode into an empty buffer,
# a regular full insertion mid-episode, and a partial insertion that starts a
# new episode after a previous one ended.
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
    push!(eb.traces, xs)
    partial = ispartial_insert(eb, xs)
    if length(eb.traces) == 0 && partial #first push should be partial
        push!(eb.step_numbers, 1)
        push!(eb.episodes_lengths, 0)
        push!(eb.sampleable_inds, 0)
    elseif !partial #typical inserting
        # NOTE(review): length(eb.traces) < length(eb) appears to detect the
        # offset introduced by a PartialNamedTuple push — confirm against samplers.
        if length(eb.traces) < length(eb) && length(eb) > 2 #case when PartialNamedTuple is used. Steps are indexable one step later
            eb.sampleable_inds[end-1] = 1
        else #case when we don't, length of traces and eb will match.
            eb.sampleable_inds[end] = 1 #previous step is now indexable
        end
        push!(eb.sampleable_inds, 0) #this one is no longer
        ep_length = last(eb.step_numbers)
        push!(eb.episodes_lengths, ep_length)
        # Backfill the running episode length over every step of this episode.
        startidx = max(1,length(eb.step_numbers) - last(eb.step_numbers))
        eb.episodes_lengths[startidx:end] .= ep_length
        push!(eb.step_numbers, ep_length + 1)
    elseif partial
        # New episode begins: pad non-multiplexed traces so indexes stay aligned.
        fill_multiplex(eb)
        eb.sampleable_inds[end] = 0 #previous step is not indexable because it contains the last state
        push!(eb.sampleable_inds, 0) #this one isn't either
        push!(eb.step_numbers, 1)
        push!(eb.episodes_lengths, 0)
    end
    return nothing
end

# Push the wrapped NamedTuple without incrementing the step number; used at the
# end of an episode to complete multiplex traces before the next episode starts.
function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple)
    push!(eb.traces, xs.namedtuple)
    eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
    return nothing # consistent with push!(::EpisodesBuffer, ::NamedTuple)
end

# Removing an element must keep the bookkeeping vectors in lockstep with the
# wrapped traces, so every removal touches all four containers.
function Base.pop!(eb::EpisodesBuffer)
    pop!(eb.episodes_lengths)
    pop!(eb.sampleable_inds)
    pop!(eb.step_numbers)
    pop!(eb.traces)
end

function Base.popfirst!(eb::EpisodesBuffer)
    popfirst!(eb.episodes_lengths)
    popfirst!(eb.sampleable_inds)
    popfirst!(eb.step_numbers)
    popfirst!(eb.traces)
end

# Clear the wrapped traces together with all episode bookkeeping.
function Base.empty!(eb::EpisodesBuffer)
    empty!(eb.traces)
    empty!(eb.sampleable_inds)
    empty!(eb.episodes_lengths)
    empty!(eb.step_numbers)
end
19 changes: 15 additions & 4 deletions src/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ arrays (see [`array_normalizer`](@ref)).

# Examples
```
t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
t = CircularArraySARTSTraces(capacity = 10, state = Float64 => (5,))
nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
# :next_state will also be normalized.
traj = Trajectory(
Expand Down Expand Up @@ -180,7 +180,7 @@ function Base.show(io::IO, ::MIME"text/plain", t::NormalizedTraces{names,T}) whe
end
end

@forward NormalizedTraces.traces Base.length, Base.size, Base.lastindex, Base.firstindex, Base.getindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.parent
@forward NormalizedTraces.traces Base.length, Base.size, Base.lastindex, Base.firstindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.parent

for f in (:push!, :pushfirst!, :append!, :prepend!)
@eval function Base.$f(nt::NormalizedTraces, x::T) where T
Expand All @@ -191,8 +191,19 @@ for f in (:push!, :pushfirst!, :append!, :prepend!)
end
end

function sample(s::BatchSampler, nt::NormalizedTraces, names)
inds = rand(s.rng, 1:length(nt), s.batch_size)
function StatsBase.sample(s::BatchSampler, nt::NormalizedTraces, names, weights = StatsBase.UnitWeights{Int}(length(nt)))
inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batch_size)
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
NamedTuple{names}(collect(maybe_normalize(nt[x][inds], x)) for x in names)
end

# Index into the wrapped traces, normalizing every trace for which a normalizer
# is registered; traces without a normalizer pass through unchanged.
function Base.getindex(nt::NormalizedTraces, inds)
    norm_or_id(data, key) = haskey(nt.normalizers, key) ? normalize(nt.normalizers[key], data) : data
    return NamedTuple{keys(nt.traces)}(collect(norm_or_id(nt.traces[k][inds], k)) for k in keys(nt.traces))
end

# Fetching a single trace by name forwards to the wrapped traces (no normalization).
Base.getindex(nt::NormalizedTraces, s::Symbol) = getindex(nt.traces, s)

ispartial_insert(traces::NormalizedTraces, xs) = ispartial_insert(traces.traces, xs)
Loading