@@ -45,24 +45,3 @@ function retrace_operator(qnetwork, policy, batch, γ, λ)
45
45
46
46
return ops # batchsize vector of operator
47
47
end
48
-
49
- batch = (state= [[1 2 3 ], [10 11 12 ]],
50
- action = [[1 2 3 ],[10 11 12 ]],
51
- log_prob = [log .([0.2 ,0.2 ,0.2 ]), log .([0.1 ,0.1 ,0.1 ])],
52
- reward = [[1f0 ,2f0 ,3f0 ],[10f0 ,11f0 ,12f0 ]],
53
- terminal= [[0 ,0 ,1 ], [0 ,0 ,0 ]],
54
- next_state = [[2 3 4 ],[11 12 13 ]])
55
-
56
- current_log_probs = [log .([0.1 / 2 ,0.1 / 3 ,0.1 / 4 ]) for i in 1 : 2 ]
57
- policy (x; is_sampling = true , is_return_log_prob = false ) = identity (x)
58
- policy (s,a) = current_log_probs[2 ]
59
- qnetwork (x, args... ) = x[1 , :]
60
- λ, γ = 0.9 , 0.99
61
- target (Qp) = Qp
62
- send_to_device (x) = identity
63
- retrace_operator (qnetwork, policy, batch, γ, λ)
64
-
65
-
66
- # computed by hand, one expression per batch element (rewards 1,2,3 first, then 10,11,12)
67
- 1 * 0.9 * 1 / 4 * (1 + 0.99 * 2 - 1 ) + 0.99 * 0.9 ^ 2 * 1 / 4 * 1 / 6 * (2 + 0.99 * 3 - 2 ) + 0.99 ^ 2 * 0.9 ^ 3 * 1 / 4 * 1 / 6 * 1 / 8 * (3 + 0.99 * 4 * 1 - 3 )
68
- 1 * 0.9 * 0.5 * (10 + 0.99 * 11 - 10 ) + 0.99 * 0.9 ^ 2 * 0.5 * 1 / 3 * (11 + 0.99 * 12 - 11 ) + 0.99 ^ 2 * 0.9 ^ 3 * 0.5 * 0.33 * 0.25 * (12 + 0.99 * 13 * 0 - 12 )
0 commit comments