@@ -86,6 +86,94 @@ using Test
86
86
@test size (eb) == (0 ,) == size (eb. traces) == size (eb. sampleable_inds) == size (eb. episodes_lengths) == size (eb. step_numbers)
87
87
show (eb);
88
88
end
89
# Exercise an `EpisodesBuffer` wrapping `CircularArraySARSATTraces` (capacity 10) when each
# episode is closed by pushing a `PartialNamedTuple` that carries only the final action.
# The bookkeeping vectors (`sampleable_inds`, `episodes_lengths`, `step_numbers`) are
# checked after every push, then after `popfirst!`, `pop!` and `empty!`.
@testset " with PartialNamedTuple" begin
    eb = EpisodesBuffer(
        CircularArraySARSATTraces(;
            capacity=10)
    )

    # First episode: an initial state followed by 5 full steps.
    push!(eb, (state = 1,))
    @test eb.sampleable_inds[end] == 0
    @test eb.episodes_lengths[end] == 0
    @test eb.step_numbers[end] == 1
    for i in 1:5
        push!(eb, (state = i + 1, action = i, reward = i, terminal = false))
        @test eb.sampleable_inds[end] == 0
        @test eb.sampleable_inds[end-1] == 1
        @test eb.step_numbers[end] == i + 1
        @test eb.episodes_lengths[end-i:end] == fill(i, i + 1)
    end
    # Close the episode with the last action only.
    push!(eb, PartialNamedTuple((action = 6,)))
    @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0]
    @test length(eb.traces) == 5

    # Second episode (6 steps in total once finished).
    push!(eb, (state = 7,))
    @test eb.sampleable_inds[end] == 0
    @test eb.sampleable_inds[end-1] == 0
    @test eb.episodes_lengths[end] == 0
    @test eb.step_numbers[end] == 1
    @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 0]
    @test eb[6][:reward] == 5 # 6 is not a valid index, the reward there is a dummy duplicate of the previous one (5)
    ep2_len = 0 # number of full steps pushed so far in the second episode
    for s in 8:11
        ep2_len += 1
        push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false))
        @test eb.sampleable_inds[end] == 0
        @test eb.sampleable_inds[end-1] == 1
        @test eb.step_numbers[end] == ep2_len + 1
        @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
    end
    @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
    @test length(eb.traces) == 9 # an action is missing at this stage

    # Keep extending the second episode. Per the original note, these last pushes
    # replace the oldest steps in the buffer.
    for s in 12:13
        ep2_len += 1
        push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false))
        @test eb.sampleable_inds[end] == 0
        @test eb.sampleable_inds[end-1] == 1
        @test eb.step_numbers[end] == ep2_len + 1
        @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
    end
    push!(eb, PartialNamedTuple((action = 13,)))
    @test length(eb.traces) == 10

    # Walk over both stored episodes: the buffer is expected to hold states 3..13, with
    # positions 4 and 11 being the non-sampleable slots.
    for (i, s) in enumerate(3:13)
        if i in (4, 11)
            @test eb.sampleable_inds[i] == 0
            continue
        else
            @test eb.sampleable_inds[i] == 1
        end
        b = eb[i]
        @test b[:state] == b[:action] == b[:reward] == s
        @test b[:next_state] == b[:next_action] == s + 1
    end

    # Start a third episode.
    push!(eb, (state = 14,))
    @test eb.sampleable_inds[end] == 0
    @test eb.sampleable_inds[end-1] == 0
    @test eb.episodes_lengths[end] == 0
    @test eb.step_numbers[end] == 1
    # Push until the episode reaches its own start in the circular buffer.
    for s in 15:26
        push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false))
    end
    push!(eb, PartialNamedTuple((action = 26,)))
    @test eb.sampleable_inds == [fill(true, 10); [false]]
    @test eb.episodes_lengths == fill(length(15:26), 11)
    @test eb.step_numbers == collect(3:13)

    # Removing from either end must keep the bookkeeping vectors one longer than the buffer.
    popfirst!(eb)
    @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
    @test first(eb.step_numbers) == 4
    pop!(eb)
    @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8
    @test last(eb.step_numbers) == 12
    @test size(eb) == size(eb.traces) == (8,)

    # Emptying resets the buffer and every bookkeeping vector.
    empty!(eb)
    @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
    show(eb);
end
89
177
@testset " with vector traces" begin
90
178
eb = EpisodesBuffer (
91
179
Traces (;
0 commit comments