Skip to content

Commit 3a05be0

Browse files
committed
add tests for partial namedtuple
1 parent 824765f commit 3a05be0

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

test/episodes.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,94 @@ using Test
8686
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
8787
show(eb);
8888
end
89+
@testset "with PartialNamedTuple" begin
90+
eb = EpisodesBuffer(
91+
CircularArraySARSATTraces(;
92+
capacity=10)
93+
)
94+
#push a first episode l=5
95+
push!(eb, (state = 1,))
96+
@test eb.sampleable_inds[end] == 0
97+
@test eb.episodes_lengths[end] == 0
98+
@test eb.step_numbers[end] == 1
99+
for i = 1:5
100+
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
101+
@test eb.sampleable_inds[end] == 0
102+
@test eb.sampleable_inds[end-1] == 1
103+
@test eb.step_numbers[end] == i + 1
104+
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
105+
end
106+
push!(eb, PartialNamedTuple((action = 6,)))
107+
@test eb.sampleable_inds == [1,1,1,1,1,0]
108+
@test length(eb.traces) == 5
109+
#start new episode of 6 periods.
110+
push!(eb, (state = 7,))
111+
@test eb.sampleable_inds[end] == 0
112+
@test eb.sampleable_inds[end-1] == 0
113+
@test eb.episodes_lengths[end] == 0
114+
@test eb.step_numbers[end] == 1
115+
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
116+
@test eb[6][:reward] == 5 #6 is not a valid index, the reward there is dummy duplicate of previous (5)
117+
ep2_len = 0
118+
for (j,i) = enumerate(8:11)
119+
ep2_len += 1
120+
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
121+
@test eb.sampleable_inds[end] == 0
122+
@test eb.sampleable_inds[end-1] == 1
123+
@test eb.step_numbers[end] == j + 1
124+
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
125+
end
126+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
127+
@test length(eb.traces) == 9 #an action is missing at this stage
128+
#three last steps replace oldest steps in the buffer.
129+
for (i, s) = enumerate(12:13)
130+
ep2_len += 1
131+
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
132+
@test eb.sampleable_inds[end] == 0
133+
@test eb.sampleable_inds[end-1] == 1
134+
@test eb.step_numbers[end] == i + 1 + 4
135+
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
136+
end
137+
push!(eb, PartialNamedTuple((action = 13,)))
138+
@test length(eb.traces) == 10
139+
#episode 1
140+
for (i,s) in enumerate(3:13)
141+
if i in (4, 11)
142+
@test eb.sampleable_inds[i] == 0
143+
continue
144+
else
145+
@test eb.sampleable_inds[i] == 1
146+
end
147+
b = eb[i]
148+
@test b[:state] == b[:action] == b[:reward] == s
149+
@test b[:next_state] == b[:next_action] == s + 1
150+
end
151+
#episode 2
152+
#start a third episode
153+
push!(eb, (state = 14,))
154+
@test eb.sampleable_inds[end] == 0
155+
@test eb.sampleable_inds[end-1] == 0
156+
@test eb.episodes_lengths[end] == 0
157+
@test eb.step_numbers[end] == 1
158+
#push until it reaches it own start
159+
for (i,s) in enumerate(15:26)
160+
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
161+
end
162+
push!(eb, PartialNamedTuple((action = 26,)))
163+
@test eb.sampleable_inds == [fill(true, 10); [false]]
164+
@test eb.episodes_lengths == fill(length(15:26), 11)
165+
@test eb.step_numbers == [3:13;]
166+
step = popfirst!(eb)
167+
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
168+
@test first(eb.step_numbers) == 4
169+
step = pop!(eb)
170+
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8
171+
@test last(eb.step_numbers) == 12
172+
@test size(eb) == size(eb.traces) == (8,)
173+
empty!(eb)
174+
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
175+
show(eb);
176+
end
89177
@testset "with vector traces" begin
90178
eb = EpisodesBuffer(
91179
Traces(;

0 commit comments

Comments
 (0)