export retrace_operator

# Note: the speed of this operator could be improved by batching states and actions
# so as to make a single call to Q instead of one call per batch sample (see the
# sketch after the function). However, this operator is not backpropagated through,
# and its computation will typically represent a minor fraction of the runtime of a
# deep RL algorithm.

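# For reference, the quantity computed for each batch sample is intended to follow
# the Retrace(λ) correction of Munos et al. (2016), "Safe and Efficient Off-Policy
# Reinforcement Learning":
#
#   ΔQ(s_1, a_1) = sum_{k=1}^{T} γ^(k-1) * (c_1 * ... * c_k) * δ_k
#   δ_k = r_k + γ * Q'(s_{k+1}, a_{k+1}) - Q'(s_k, a_k)
#   c_i = λ * min(1, π(a_i|s_i) / μ(a_i|s_i))
#
# where π is the current policy, μ the behavior policy that generated the batch, and
# Q' the target network. The final bootstrap uses a single action sampled from π in
# place of the expectation E_π[Q'(s_{T+1}, ·)]. (The paper's truncation product
# starts at the second step; this implementation also weights the first δ by c_1.)
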
function retrace_operator(qnetwork, policy, batch, γ, λ)
    s = batch[:state] |> send_to_device(qnetwork)
    a = batch[:action] |> send_to_device(qnetwork)
    behavior_log_probs = batch[:log_prob] |> send_to_device(qnetwork)
    r = batch[:reward] |> send_to_device(qnetwork)
    t = last.(batch[:terminal]) |> send_to_device(qnetwork) # terminal flag of each sample's last transition
    ns = batch[:next_state] |> send_to_device(qnetwork)
    na = map(ns) do ns # sample next actions from the current policy
        policy(ns, is_sampling = true, is_return_log_prob = false)
    end
    # Append the last next state (and a freshly sampled action for it) so the target
    # Q can be evaluated at T + 1 points, yielding T temporal differences per sample.
    states = map(s, ns) do s, ns
        cat(s, last(eachslice(ns, dims = ndims(ns))), dims = ndims(s))
    end
    actions = map(a, na) do a, na
        cat(a, last(eachslice(na, dims = ndims(na))), dims = ndims(a))
    end

    current_log_probs = map(s, a) do s, a # log π(a|s) under the current policy
        policy(s, a)
    end

    traces = map(current_log_probs, behavior_log_probs) do p, m
        @. λ * min(1, exp(p - m)) # truncated importance ratios c = λ * min(1, π/μ)
    end
    is_ratios = cumprod.(traces) # batch-sized vector [[c1, c1*c2, ..., c1*⋯*cT], ...]

    Qp = target(qnetwork)

    δs = map(states, actions, r, t) do s, a, r, t
        q = vec(Qp(vcat(s, a)))
        q[end] *= 1 - t # zero the bootstrapped value if the last transition is terminal
        r .+ γ .* q[2:end] .- q[1:end-1] # temporal differences δ_k
    end

    ops = map(is_ratios, δs) do ratios, deltas
        γs = γ .^ (0:length(deltas)-1)
        sum(γs .* ratios .* deltas)
    end

    return ops # batch-sized vector of Retrace corrections
end
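
# Minimal sketch (not used by `retrace_operator`) of the batching idea from the note
# at the top: stack every (state, action) pair into one matrix, call the target
# network once, and split the values back per sample. It assumes all samples share
# the same trajectory length and the (features × steps) layout used above;
# `batched_q_values` is a hypothetical helper name, not part of the exported API.
function batched_q_values(Qp, states, actions)
    n = size(first(states), ndims(first(states)))  # steps per sample, incl. the bootstrap point
    inputs = reduce(hcat, vcat.(states, actions))  # one (state_dim + action_dim) × (n * batchsize) matrix
    q = vec(Qp(inputs))                            # single call to the target network
    [q[(i - 1) * n + 1 : i * n] for i in eachindex(states)] # re-split per sample
end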

# Self-contained smoke test with stub policy, Q-network, target, and device transfer.
batch = (state = [[1 2 3], [10 11 12]],
         action = [[1 2 3], [10 11 12]],
         log_prob = [log.([0.2, 0.2, 0.2]), log.([0.1, 0.1, 0.1])],
         reward = [[1f0, 2f0, 3f0], [10f0, 11f0, 12f0]],
         terminal = [[0, 0, 1], [0, 0, 0]],
         next_state = [[2 3 4], [11 12 13]])

current_log_probs = [log.([0.1/2, 0.1/3, 0.1/4]) for _ in 1:2] # stub log π(a|s), identical for both samples
policy(x; is_sampling = true, is_return_log_prob = false) = identity(x) # stub sampling: return the state itself
policy(s, a) = current_log_probs[1] # stub log-prob evaluation
qnetwork(x, args...) = x[1, :] # stub Q: the value of (s, a) is the first row of the input
λ, γ = 0.9, 0.99
target(Qp) = Qp # stub: target network is the online network
send_to_device(x) = identity # stub: no device transfer
retrace_operator(qnetwork, policy, batch, γ, λ)
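# With the stubs above, the call should return approximately [0.5447, 6.8714], one
# Retrace correction per batch sample; the hand computations below reproduce both.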

# Hand computation, term by term (first line: batch sample 1, second: batch sample 2):
1*0.9*1/4*(1+0.99*2-1) + 0.99*0.9^2*1/4*1/6*(2+0.99*3-2) + 0.99^2*0.9^3*1/4*1/6*1/8*(3+0.99*4*0-3)
1*0.9*1/2*(10+0.99*11-10) + 0.99*0.9^2*1/2*1/3*(11+0.99*12-11) + 0.99^2*0.9^3*1/2*1/3*1/4*(12+0.99*13*1-12)
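
# A sketch of turning the hand checks into a regression test; the literals below are
# the evaluated sums above and assume the stub definitions in this file:
ops = retrace_operator(qnetwork, policy, batch, γ, λ)
@assert isapprox(ops, [0.5447351, 6.8714088]; atol = 1e-4)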