Commit 2d1fe56

add retrace
1 parent a8ae878 commit 2d1fe56

4 files changed, +94 -2 lines changed


src/ReinforcementLearningZoo/src/algorithms/algorithms.jl

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 include("dqns/dqns.jl")
 # include("policy_gradient/policy_gradient.jl")
 include("policy_gradient/policy_gradient.jl")
+include("bootstrapping/retrace.jl")
 # include("searching/searching.jl")
 # include("cfr/cfr.jl")
 # include("offline_rl/offline_rl.jl")
src/ReinforcementLearningZoo/src/algorithms/bootstrapping/retrace.jl

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+export retrace_operator
+# Note: the speed of this operator may be improved by batching states and actions
+# to make single calls to Q, instead of one per batch sample. However, this operator
+# is not backpropagated through, and its computation will typically represent a minor
+# fraction of the runtime of a deep RL algorithm.
+
+function retrace_operator(qnetwork, policy, batch, γ, λ)
+    s = batch[:state] |> send_to_device(qnetwork)
+    a = batch[:action] |> send_to_device(qnetwork)
+    behavior_log_probs = batch[:log_prob] |> send_to_device(qnetwork)
+    r = batch[:reward] |> send_to_device(qnetwork)
+    t = last.(batch[:terminal]) |> send_to_device(qnetwork)
+    ns = batch[:next_state] |> send_to_device(qnetwork)
+    na = map(ns) do ns
+        policy(ns, is_sampling = true, is_return_log_prob = false)
+    end
+    states = map(s, ns) do s, ns # concatenate all states, including the last state, to compute deltas with the target Q
+        cat(s, last(eachslice(ns, dims = ndims(ns))), dims = ndims(s))
+    end
+    actions = map(a, na) do a, na
+        cat(a, last(eachslice(na, dims = ndims(na))), dims = ndims(a))
+    end
+
+    current_log_probs = map(s, a) do s, a
+        policy(s, a)
+    end
+
+    traces = map(current_log_probs, behavior_log_probs) do p, m
+        @. λ * min(1, exp(p - m))
+    end
+    is_ratios = cumprod.(traces) # batchsize-length vector [[c1,c2,...,ct],[c1,c2,...,ct],...]
+
+    Qp = target(qnetwork)
+
+    δs = map(states, actions, r, t) do s, a, r, t
+        q = vec(Qp(vcat(s, a)))
+        q[end] *= t
+        r .+ q[2:end] .* γ .- q[1:end-1]
+    end
+
+    ops = map(is_ratios, δs) do ratios, deltas
+        γs = γ .^ (0:(length(deltas)-1))
+        sum(γs .* ratios .* deltas)
+    end
+
+    return ops # batchsize-length vector of operators
+end
+
+batch = (state = [[1 2 3], [10 11 12]],
+    action = [[1 2 3], [10 11 12]],
+    log_prob = [log.([0.2, 0.2, 0.2]), log.([0.1, 0.1, 0.1])],
+    reward = [[1f0, 2f0, 3f0], [10f0, 11f0, 12f0]],
+    terminal = [[0, 0, 1], [0, 0, 0]],
+    next_state = [[2 3 4], [11 12 13]])
+
+current_log_probs = [log.([0.1/2, 0.1/3, 0.1/4]) for i in 1:2]
+policy(x; is_sampling = true, is_return_log_prob = false) = identity(x)
+policy(s, a) = current_log_probs[2]
+qnetwork(x, args...) = x[1, :]
+λ, γ = 0.9, 0.99
+target(Qp) = Qp
+send_to_device(x) = identity
+retrace_operator(qnetwork, policy, batch, γ, λ)
+
+
+# compute by hand for batch[2]
+1*0.9*1/4*(1+0.99*2-1) + 0.99*0.9^2*1/4*1/6*(2+0.99*3-2) + 0.99^2*0.9^3*1/4*1/6*1/8*(3+0.99*4*1-3)
+1*0.9*0.5*(10+0.99*11-10) + 0.99*0.9^2*0.5*1/3*(11+0.99*12-11) + 0.99^2*0.9^3*0.5*0.33*0.25*(12+0.99*13*0-12)
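
For reference, the quantity returned by retrace_operator for each sample in the batch can be written compactly. In the sketch below (my reading of the convention implemented above, not necessarily the canonical Retrace(λ) statement), π is the current policy, μ the behavior policy whose log-probabilities are stored in batch[:log_prob], Q' the target network, T the trajectory length of the sample, a_{T+1} an action sampled from π at s_{T+1}, and the bootstrap value Q'(s_{T+1}, a_{T+1}) is multiplied by the sample's last terminal flag:

    c_i = \lambda \, \min\bigl(1, \exp(\log \pi(a_i \mid s_i) - \log \mu(a_i \mid s_i))\bigr)
    \delta_i = r_i + \gamma \, Q'(s_{i+1}, a_{i+1}) - Q'(s_i, a_i)
    \mathrm{op} = \sum_{i=1}^{T} \gamma^{i-1} \Bigl(\prod_{j=1}^{i} c_j\Bigr) \delta_i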
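
The header comment notes that the operator could be made faster by batching states and actions into a single call to Q. A minimal sketch of that idea in Julia, assuming every sample in the batch shares the same trajectory length and reusing the local names states, actions, and Qp from retrace_operator (illustrative only, not part of the commit):

    # Hypothetical batched evaluation of the target network; assumes equal
    # trajectory lengths across the batch.
    sa = reduce(hcat, map(vcat, states, actions))  # (state_dim + action_dim) × (B * (T+1))
    qs = vec(Qp(sa))                               # one forward pass for the whole batch
    q_cols = reshape(qs, :, length(states))        # (T+1) × B
    q_per_sample = [q_cols[:, k] for k in 1:size(q_cols, 2)]  # B vectors of length T+1

The per-sample work that follows (scaling the bootstrap value by the terminal flag and forming r .+ q[2:end] .* γ .- q[1:end-1]) would then proceed exactly as in the map over δs above.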

test/operators.jl

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import ReinforcementLearningCore
+@testset "retrace" begin
+    batch = (state = [[1 2 3], [10 11 12]],
+        action = [[1 2 3], [10 11 12]],
+        log_prob = [log.([0.2, 0.2, 0.2]), log.([0.1, 0.1, 0.1])],
+        reward = [[1f0, 2f0, 3f0], [10f0, 11f0, 12f0]],
+        terminal = [[0, 0, 1], [0, 0, 0]],
+        next_state = [[2 3 4], [11 12 13]])
+
+    # define a fake policy where a = x and which always returns the same log probabilities
+    policy(x; is_sampling = true, is_return_log_prob = false) = identity(x)
+    policy(s, a) = log.([0.1/2, 0.1/3, 0.1/4]) # both samples have the same current log probs
+    qnetwork(x, args...) = x[1, :] # the value of a state is its number
+    λ, γ = 0.9, 0.99
+    ReinforcementLearningCore.target(qnetwork) = qnetwork
+    ops = retrace_operator(qnetwork, policy, batch, γ, λ)
+    # handmade calculation of the correct ops
+    op1 = 1*0.9*1/4*(1+0.99*2-1) + 0.99*0.9^2*1/4*1/6*(2+0.99*3-2) + 0.99^2*0.9^3*1/4*1/6*1/8*(3+0.99*4*1-3)
+    op2 = 1*0.9*0.5*(10+0.99*11-10) + 0.99*0.9^2*0.5*1/3*(11+0.99*12-11) + 0.99^2*0.9^3*0.5*0.33*0.25*(12+0.99*13*0-12)
+    println("test")
+    @test ops == [op1, op2]
+end
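
As a cross-check of the handmade values (my restatement, not part of the commit), op2 is the operator formula instantiated for the second sample: the truncated importance ratios are min(1, (0.1/2)/0.1) = 1/2, min(1, (0.1/3)/0.1) = 1/3, and min(1, (0.1/4)/0.1) = 1/4, each scaled by λ = 0.9 and cumulatively multiplied, while the TD errors use q = (10, 11, 12, 13 * 0), the last value being zeroed because the sample's final terminal flag is 0 under this commit's convention:

    \mathrm{op}_2 = (0.9 \cdot \tfrac{1}{2}) \cdot (0.99 \cdot 11)
                  + 0.99 \cdot (0.9 \cdot \tfrac{1}{2})(0.9 \cdot \tfrac{1}{3}) \cdot (0.99 \cdot 12)
                  + 0.99^2 \cdot c_1 c_2 c_3 \cdot 0

This matches the op2 expression in the test term by term; the 0.33 and 0.25 factors in the third term do not affect the result because the third TD error is zero.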

test/runtests.jl

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 using Test
-using ReinforcementLearning
+using ReinforcementLearningZoo

-@testset "ReinforcementLearning" begin
+@testset "ReinforcementLearningZoo" begin
+    include("operators.jl")
 end
