Skip to content

Commit f222bad

Browse files
Merge pull request #66 from JuliaReinforcementLearning/jpsl/elastic
Add Missing Elastic Methods and tests
2 parents 9625f16 + 4895a41 commit f222bad

File tree

12 files changed

+339
-38
lines changed

12 files changed

+339
-38
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
version:
21-
- '1.9'
2221
- '1'
22+
- '^1.11.0-alpha'
2323
- 'nightly'
2424
os:
2525
- ubuntu-latest

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.3.7"
3+
version = "0.4"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/common/ElasticArraySARTTraces.jl renamed to src/common/ElasticArraySARTSATraces.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
export ElasticArraySARTTraces
1+
export ElasticArraySARTSATraces
22

3-
using ElasticArrays: ElasticArray, resize_lastdim!
4-
5-
const ElasticArraySARTTraces = Traces{
3+
const ElasticArraySARTSATraces = Traces{
64
SS′AA′RT,
75
<:Tuple{
86
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
@@ -12,7 +10,7 @@ const ElasticArraySARTTraces = Traces{
1210
}
1311
}
1412

15-
function ElasticArraySARTTraces(;
13+
function ElasticArraySARTSATraces(;
1614
state=Int => (),
1715
action=Int => (),
1816
reward=Float32 => (),
@@ -31,10 +29,3 @@ function ElasticArraySARTTraces(;
3129
)
3230
end
3331

34-
#####
35-
# extensions for ElasticArrays
36-
#####
37-
38-
Base.push!(a::ElasticArray, x) = append!(a, x)
39-
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
40-
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)

src/common/ElasticArraySARTSTraces.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
export ElasticArraySARTSTraces
2+
3+
# Alias for a `Traces` container named `SS′ART`: one `MultiplexTraces{SS′}` followed by
# three plain `Trace`s, every one of them wrapping an `ElasticArray`.
const ElasticArraySARTSTraces = Traces{
    SS′ART,
    <:Tuple{
        <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
        <:Trace{<:ElasticArray},
        <:Trace{<:ElasticArray},
        <:Trace{<:ElasticArray},
    }
}

"""
    ElasticArraySARTSTraces(; state, action, reward, terminal)

Build an empty set of traces whose storage is made of `ElasticArray`s.

# Keywords
Each keyword is destructured into an element type and a size tuple
(written `eltype => size` in the defaults):
- `state=Int => ()`: stored in a single array wrapped in `MultiplexTraces{SS′}`.
- `action=Int => ()`
- `reward=Float32 => ()`
- `terminal=Bool => ()`

Every array is created with the requested size plus one trailing dimension of
length 0.

# Returns
The sum (`+`) of the multiplexed state trace and a `Traces` holding `action`,
`reward` and `terminal`.
"""
function ElasticArraySARTSTraces(;
    state=Int => (),
    action=Int => (),
    reward=Float32 => (),
    terminal=Bool => ())

    # Uninitialised storage of element shape `dims` with an empty trailing axis.
    blank(T, dims) = ElasticArray{T}(undef, dims..., 0)

    S, s_dims = state
    A, a_dims = action
    R, r_dims = reward
    D, d_dims = terminal

    states = MultiplexTraces{SS′}(blank(S, s_dims))
    others = Traces(
        action=blank(A, a_dims),
        reward=blank(R, r_dims),
        terminal=blank(D, d_dims),
    )
    return states + others
end

src/common/ElasticArraySLARTTraces.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
export ElasticArraySLARTTraces
2+
3+
# Alias for a `Traces` container named `SS′LL′AA′RT`: three `MultiplexTraces`
# (`SS′`, `LL′`, `AA′`) followed by two plain `Trace`s, all wrapping `ElasticArray`s.
const ElasticArraySLARTTraces = Traces{
    SS′LL′AA′RT,
    <:Tuple{
        <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
        <:MultiplexTraces{LL′,<:Trace{<:ElasticArray}},
        <:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
        <:Trace{<:ElasticArray},
        <:Trace{<:ElasticArray},
    }
}

"""
    ElasticArraySLARTTraces(; state, legal_actions_mask, action, reward, terminal)

Build an empty set of traces whose storage is made of `ElasticArray`s.

# Keywords
- `capacity=nothing`: accepted so that existing calls passing `capacity=...` keep
  working, but never read by this constructor. It used to be a required keyword;
  it is now optional.
- `state=Int => ()`, `legal_actions_mask=Bool => ()`, `action=Int => ()`: each is
  destructured into an element type and a size tuple and stored in one array wrapped
  in `MultiplexTraces{SS′}`, `MultiplexTraces{LL′}` and `MultiplexTraces{AA′}`
  respectively.
- `reward=Float32 => ()`, `terminal=Bool => ()`: destructured the same way and stored
  in a plain `Traces`.

Every array is created with the requested size plus one trailing dimension of
length 0.

# Returns
The sum (`+`) of the three multiplexed traces and the `reward`/`terminal` traces.
"""
function ElasticArraySLARTTraces(;
    capacity::Union{Int,Nothing}=nothing,
    state=Int => (),
    legal_actions_mask=Bool => (),
    action=Int => (),
    reward=Float32 => (),
    terminal=Bool => ()
)
    state_eltype, state_size = state
    action_eltype, action_size = action
    legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask
    reward_eltype, reward_size = reward
    terminal_eltype, terminal_size = terminal

    # Uninitialised storage of element shape `dims` with an empty trailing axis.
    blank(T, dims) = ElasticArray{T}(undef, dims..., 0)

    states = MultiplexTraces{SS′}(blank(state_eltype, state_size))
    masks = MultiplexTraces{LL′}(blank(legal_actions_mask_eltype, legal_actions_mask_size))
    actions = MultiplexTraces{AA′}(blank(action_eltype, action_size))
    others = Traces(
        reward=blank(reward_eltype, reward_size),
        terminal=blank(terminal_eltype, terminal_size),
    )
    return states + masks + actions + others
end

src/common/common.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,7 @@ include("CircularArraySARTSTraces.jl")
1414
include("CircularArraySARTSATraces.jl")
1515
include("CircularArraySLARTTraces.jl")
1616
include("CircularPrioritizedTraces.jl")
17-
include("ElasticArraySARTTraces.jl")
17+
include("common_elastic_array.jl")
18+
include("ElasticArraySARTSTraces.jl")
19+
include("ElasticArraySARTSATraces.jl")
20+
include("ElasticArraySLARTTraces.jl")

src/common/common_elastic_array.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using ElasticArrays: ElasticArray, resize_lastdim!
2+
3+
#####
4+
# extensions for ElasticArrays
5+
#####
6+
7+
# General case: `push!` simply forwards `x` to `append!`.
function Base.push!(a::ElasticArray, x)
    return append!(a, x)
end

# One-dimensional case: wrap the single element in a vector before appending.
function Base.push!(a::ElasticArray{T,1}, x) where {T}
    return append!(a, [x])
end

# Emptying is done by resizing the last dimension to 0.
function Base.empty!(a::ElasticArray)
    return resize_lastdim!(a, 0)
end

src/episodes.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export EpisodesBuffer, PartialNamedTuple
22
import DataStructures.CircularBuffer
3+
using ElasticArrays: ElasticArray, ElasticVector
34

45
"""
56
EpisodesBuffer(traces::AbstractTraces)
@@ -68,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces)
6869
end
6970
end
7071

71-
Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...)
72-
Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...)
73-
Base.size(es::EpisodesBuffer) = size(es.traces)
74-
Base.length(es::EpisodesBuffer) = length(es.traces)
75-
Base.keys(es::EpisodesBuffer) = keys(es.traces)
76-
Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces)
72+
# Integer indexing: look the indices up in `sampleable_inds` (inside `@boundscheck`),
# then forward them to the wrapped traces.
# NOTE(review): the Bool produced by `all` is discarded, so a `false` entry in
# `sampleable_inds` does not stop the lookup; only an error raised by the indexing
# itself would. Confirm whether a throw on non-sampleable indices was intended.
function Base.getindex(es::EpisodesBuffer, idx::Int...)
    @boundscheck all(getindex(es.sampleable_inds, idx...))
    return es.traces[idx...]
end
76+
77+
# Any other kind of index is forwarded unchanged to the wrapped traces.
Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...)
80+
81+
# Writes are forwarded to the wrapped traces with all arguments untouched.
function Base.setindex!(eb::EpisodesBuffer, args...)
    return setindex!(eb.traces, args...)
end

# `size`, `length` and `keys` all report whatever the wrapped traces report.
for f in (:size, :length, :keys)
    @eval Base.$f(eb::EpisodesBuffer) = $f(eb.traces)
end

# For prioritized traces, the keys are those of the inner traces.
function Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces})
    return keys(eb.traces.traces)
end
7786
function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names}
7887
s = nameof(typeof(eb))
7988
t = eb.traces
@@ -82,14 +91,16 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where
8291
end
8392

8493
ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps.
85-
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
94+
ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs)
8695
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)
8796

8897
# Pad the storage wrapped by `trace`; always returns `nothing`.
pad!(trace::Trace) = (pad!(trace.parent); nothing)
92101

102+
# Pad an `ElasticArray` with one all-zero frame along its last dimension.
#
# The previous signatures, `ElasticArray{T, Vector{T}}` and
# `ElasticVector{T, Vector{T}}`, put `Vector{T}` in a position that is not the storage
# type (the second parameter of `ElasticArray` is its dimensionality, as in
# `ElasticArray{T,1}`), so they could not match an actual array. Dispatching on
# `ElasticArray{T}` covers vectors and higher-dimensional arrays alike.
function pad!(vect::ElasticArray{T}) where {T}
    # A frame has the size of every dimension but the last, plus a trailing axis of
    # length 1; for a vector this is a one-element vector.
    frame = zeros(T, size(vect)[1:end-1]..., 1)
    return append!(vect, frame)
end

# Circular buffers and plain vectors are padded by pushing a single zero.
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
95106

@@ -123,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
123134
return :($ex)
124135
end
125136

126-
fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
137+
fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)
127138

128-
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
139+
fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)
129140

130141
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
131142
push!(eb.traces, xs)
@@ -162,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
162173
end
163174

164175
for f in (:pop!, :popfirst!)
165-
@eval function Base.$f(es::EpisodesBuffer)
166-
$f(es.episodes_lengths)
167-
$f(es.sampleable_inds)
168-
$f(es.step_numbers)
169-
$f(es.traces)
176+
@eval function Base.$f(eb::EpisodesBuffer)
177+
$f(eb.episodes_lengths)
178+
$f(eb.sampleable_inds)
179+
$f(eb.step_numbers)
180+
$f(eb.traces)
170181
end
171182
end
172183

173-
function Base.empty!(es::EpisodesBuffer)
174-
empty!(es.traces)
175-
empty!(es.episodes_lengths)
176-
empty!(es.sampleable_inds)
177-
empty!(es.step_numbers)
184+
function Base.empty!(eb::EpisodesBuffer)
185+
empty!(eb.traces)
186+
empty!(eb.episodes_lengths)
187+
empty!(eb.sampleable_inds)
188+
empty!(eb.step_numbers)
178189
end

src/traces.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces
33
import MacroTools: @forward
44

55
import CircularArrayBuffers.CircularArrayBuffer
6+
using ElasticArrays: ElasticArray
67
import Adapt
78

89
#####
@@ -55,6 +56,7 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
5556
# A trace reports the capacity of its parent storage.
function capacity(t::AbstractTrace)
    return ReinforcementLearningTrajectories.capacity(t.parent)
end

# Circular buffers delegate to `CircularArrayBuffers.capacity`.
function capacity(t::CircularArrayBuffer)
    return CircularArrayBuffers.capacity(t)
end

# Plain vectors and elastic arrays both report `Inf`.
function capacity(::AbstractVector)
    return Inf
end

function capacity(::ElasticArray)
    return Inf
end
5860

5961
#####
6062

test/common.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,15 @@ end
9494
@test batch.terminal == Bool[0, 0, 0] |> gpu
9595
end
9696

97-
@testset "ElasticArraySARTTraces" begin
98-
t = ElasticArraySARTTraces(;
97+
@testset "ElasticArraySARTSTraces" begin
98+
t = ElasticArraySARTSTraces(;
9999
state=Float32 => (2, 3),
100100
action=Int => (),
101101
reward=Float32 => (),
102102
terminal=Bool => ()
103103
)
104104

105-
@test t isa ElasticArraySARTTraces
105+
@test t isa ElasticArraySARTSTraces
106106

107107
push!(t, (state=ones(Float32, 2, 3), action=1))
108108
push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2))
@@ -185,4 +185,4 @@ end
185185

186186
eb[:priority, [1, 2]] = [0, 0]
187187
@test eb[:priority] == [zeros(2);ones(8)]
188-
end
188+
end

0 commit comments

Comments
 (0)