1
1
export RandomWalk1D
2
2
3
3
"""
4
- RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2)
4
+ RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])
5
5
6
- An agent is placed at the `start_pos` and can either move `:left` or `:right`.
7
- The game terminates when the agent reaches either end and receives a reward
8
- correspondingly.
6
+ An agent is placed at the `start_pos` and can move left or right (stride is
7
+ defined in `actions`). The game terminates when the agent reaches either end and
8
+ receives a reward correspondingly.
9
9
10
10
Compared to the [`MultiArmBanditsEnv`](@ref):
11
11
@@ -16,24 +16,15 @@ Compared to the [`MultiArmBanditsEnv`](@ref):
16
16
Base. @kwdef mutable struct RandomWalk1D <: AbstractEnv
17
17
rewards:: Pair{Float64,Float64} = - 1.0 => 1.0
18
18
N:: Int = 7
19
+ actions:: Vector{Int} = [- 1 , 1 ]
19
20
start_pos:: Int = (N + 1 ) ÷ 2
20
21
pos:: Int = start_pos
21
22
end
22
23
23
- const ACTIONS_OF_RANDOMWALK1D = ( :left , :right )
24
+ RLBase . action_space (env :: RandomWalk1D ) = Base . OneTo ( length (env . actions) )
24
25
25
- RLBase. action_space (:: RandomWalk1D ) = Base. OneTo (length (ACTIONS_OF_RANDOMWALK1D))
26
-
27
- (env:: RandomWalk1D )(action:: Int ) = env (ACTIONS_OF_RANDOMWALK1D[action])
28
-
29
- function (env:: RandomWalk1D )(action:: Symbol )
30
- if action == :left
31
- env. pos = max (env. pos - 1 , 1 )
32
- elseif action == :right
33
- env. pos = min (env. pos + 1 , env. N)
34
- else
35
- @error " invalid action: $action "
36
- end
26
+ function (env:: RandomWalk1D )(action)
27
+ env. pos = max (min (env. pos + env. actions[action], env. N), 1 )
37
28
end
38
29
39
30
RLBase. state (env:: RandomWalk1D ) = env. pos
0 commit comments