Skip to content

Commit ca99dd1

Browse files
Add test for bug
1 parent a36ddc3 commit ca99dd1

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

test/samplers.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ end
138138
end
139139
#! format: on
140140

141-
@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
141+
@testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
142142
n=1
143143
γ=0.99f0
144144

@@ -169,3 +169,33 @@ end
169169
b = RLTrajectories.StatsBase.sample(t)
170170
@test haskey(b, :priority)
171171
end
172+
173+
174+
@testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
175+
n=1
176+
γ=0.99f0
177+
178+
t = Trajectory(
179+
container=CircularArraySARTSTraces(
180+
capacity=5,
181+
state=Float32 => (4,),
182+
),
183+
sampler=NStepBatchSampler{SS′ART}(
184+
n=n,
185+
γ=γ,
186+
batch_size=32,
187+
),
188+
controller=InsertSampleRatioController(
189+
threshold=100,
190+
n_inserted=-1
191+
)
192+
)
193+
194+
push!(t, (state = 1, action = true))
195+
for i = 1:9
196+
push!(t, (state = i+1, action = true, reward = i, terminal = false))
197+
end
198+
199+
b = RLTrajectories.StatsBase.sample(t)
200+
@test haskey(b, :priority)
201+
end

0 commit comments

Comments
 (0)