Skip to content

Commit 8256e76

Browse files
committed
2 parents e036f63 + 7673022 commit 8256e76

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

src/ReinforcementLearningZoo/src/algorithms/bootstrapping/retrace.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export retrace
77
function retrace_operator(qnetwork, policy, batch, γ, λ)
88
s = batch[:state] |> send_to_device(qnetwork)
99
a = batch[:action] |> send_to_device(qnetwork)
10-
behavior_log_probs = batch[:action_log_problog_prob] |> send_to_device(qnetwork)
10+
behavior_log_probs = batch[:action_log_prob] |> send_to_device(qnetwork)
1111
r = batch[:reward] |> send_to_device(qnetwork)
1212
t = last.(batch[:terminal]) |> send_to_device(qnetwork)
1313
ns = batch[:next_state] |> send_to_device(qnetwork)

test/operators.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import ReinforcementLearningCore
2+
@testset "retrace" begin
    # Two hand-built samples of three steps each. The behaviour
    # log-probabilities are stored under `:action_log_prob`, which is the key
    # `retrace_operator` looks up in the batch.
    batch = (
        state = [[1 2 3], [10 11 12]],
        action = [[1 2 3], [10 11 12]],
        action_log_prob = [log.([0.2, 0.2, 0.2]), log.([0.1, 0.1, 0.1])],
        reward = [[1f0, 2f0, 3f0], [10f0, 11f0, 12f0]],
        terminal = [[0, 0, 1], [0, 0, 0]],
        next_state = [[2 3 4], [11 12 13]],
    )

    # Fake policy: the one-argument method returns its input (a = x); the
    # two-argument method always returns the same log-probabilities, so both
    # samples share the same current log-probs.
    policy(x; is_sampling = true, is_return_log_prob = false) = identity(x)
    policy(s, a) = log.([0.1 / 2, 0.1 / 3, 0.1 / 4])
    qnetwork(x, args...) = x[1, :]  # the value of a state is its number
    λ, γ = 0.9, 0.99
    # Make the fake network act as its own target network.
    ReinforcementLearningCore.target(qnetwork) = qnetwork
    ops = retrace_operator(qnetwork, policy, batch, γ, λ)

    # Hand-computed expected values of the operator for each sample.
    op1 = 1*0.9*1/4*(1+0.99*2-1) +
          0.99*0.9^2*1/4*1/6*(2+0.99*3-2) +
          0.99^2*0.9^3*1/4*1/6*1/8*(3+0.99*4*1-3)
    op2 = 1*0.9*0.5*(10+0.99*11-10) +
          0.99*0.9^2*0.5*1/3*(11+0.99*12-11) +
          0.99^2*0.9^3*0.5*0.33*0.25*(12+0.99*13*0-12)

    # Floating-point results are compared approximately rather than with `==`.
    @test ops ≈ [op1, op2]
end

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
2-
using ReinforcementLearning
2+
using ReinforcementLearningZoo
33

4-
@testset "ReinforcementLearning" begin
4+
@testset "ReinforcementLearningZoo" begin
5+
include("operators.jl")
56
end

0 commit comments

Comments
 (0)