Skip to content

Commit 76c5fc5

Browse files
Fix buffer size issue
1 parent 39ea7a5 commit 76c5fc5

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

src/common/CircularPrioritizedTraces.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,6 @@ function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
5959
end
6060
end
6161

62-
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
62+
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
63+
64+
capacity(t::CircularPrioritizedTraces) = ReinforcementLearningTrajectories.capacity(t.traces)

src/episodes.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,21 @@ struct PartialNamedTuple{T}
3939
namedtuple::T
4040
end
4141

42+
# Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 for certain cases
43+
function is_capacity_plus_one(traces::AbstractTraces)
44+
if any(t->t isa MultiplexTraces, traces.traces)
45+
# MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity
46+
return true
47+
elseif traces isa CircularPrioritizedTraces
48+
# CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity
49+
return true
50+
else
51+
false
52+
end
53+
end
54+
4255
function EpisodesBuffer(traces::AbstractTraces)
43-
cap = any(t->t isa MultiplexTraces, traces.traces) ? capacity(traces) + 1 : capacity(traces)
56+
cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces)
4457
@assert isempty(traces) "EpisodesBuffer must be initialized with empty traces."
4558
if !isinf(cap)
4659
legalinds = CircularBuffer{Bool}(cap)

test/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ end
145145
t = Trajectory(
146146
container=CircularPrioritizedTraces(
147147
CircularArraySARTSTraces(
148-
capacity=1000,
148+
capacity=5,
149149
state=Float32 => (4,),
150150
);
151151
default_priority=100.0f0

0 commit comments

Comments
 (0)