Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 6e1cb9a

Browse files
committed
extend RandomWalk1D
1 parent b8e3541 commit 6e1cb9a

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

src/environments/examples/RandomWalk1D.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
export RandomWalk1D
22

33
"""
4-
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2)
4+
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])
55
6-
An agent is placed at the `start_pos` and can either move `:left` or `:right`.
7-
The game terminates when the agent reaches either end and receives a reward
8-
correspondingly.
6+
An agent is placed at the `start_pos` and can move left or right (stride is
7+
defined in `actions`). The game terminates when the agent reaches either end and
8+
receives a reward correspondingly.
99
1010
Compared to the [`MultiArmBanditsEnv`](@ref):
1111
@@ -16,24 +16,15 @@ Compared to the [`MultiArmBanditsEnv`](@ref):
1616
Base.@kwdef mutable struct RandomWalk1D <: AbstractEnv
    # Pair of reward values. Presumably `first` applies at the left end and
    # `last` at the right end -- the reward lookup is not visible here, confirm.
    rewards::Pair{Float64,Float64} = -1.0 => 1.0
    # Upper bound on `pos`; the step functor caps the position at `N`.
    N::Int = 7
    # Signed stride added to `pos`, looked up by action index.
    actions::Vector{Int} = Int[-1, 1]
    # Position the walk starts from; defaults to `(N + 1) ÷ 2`, which is the
    # middle cell when `N` is odd.
    start_pos::Int = div(N + 1, 2)
    # Current position of the agent; initialised to `start_pos`.
    pos::Int = start_pos
end
2223

23-
const ACTIONS_OF_RANDOMWALK1D = (:left, :right)
24+
# The action space is the set of valid indices into `env.actions`:
# one action per configured stride.
function RLBase.action_space(env::RandomWalk1D)
    n_actions = length(env.actions)
    return Base.OneTo(n_actions)
end
2425

25-
RLBase.action_space(::RandomWalk1D) = Base.OneTo(length(ACTIONS_OF_RANDOMWALK1D))
26-
27-
(env::RandomWalk1D)(action::Int) = env(ACTIONS_OF_RANDOMWALK1D[action])
28-
29-
function (env::RandomWalk1D)(action::Symbol)
30-
if action == :left
31-
env.pos = max(env.pos - 1, 1)
32-
elseif action == :right
33-
env.pos = min(env.pos + 1, env.N)
34-
else
35-
@error "invalid action: $action"
36-
end
26+
# Advance the walk by one step. `action` is an index into `env.actions` that
# selects the signed stride; the resulting position is capped at `env.N` and
# then floored at 1. Returns the new position.
function (env::RandomWalk1D)(action)
    stride = env.actions[action]
    target = env.pos + stride
    # Apply the upper cap first and the lower floor last, so the floor of 1
    # takes precedence (same evaluation order as the nested max/min form).
    capped = min(target, env.N)
    clamped = max(capped, 1)
    env.pos = clamped
    return clamped
end
3829

3930
# The observable state of the environment is simply the agent's current
# position.
function RLBase.state(env::RandomWalk1D)
    return env.pos
end

0 commit comments

Comments
 (0)