Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 9220f4b

Browse files
authored
Support rlintro (#119)
* support string representation for TicTacToeEnv * fix bugs with DefaultStateStyleEnv * extend RandomWalk1D
1 parent 4e3eb9e commit 9220f4b

File tree

3 files changed

+35
-19
lines changed

3 files changed

+35
-19
lines changed

src/environments/examples/RandomWalk1D.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
export RandomWalk1D
22

33
"""
4-
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2)
4+
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])
55
6-
An agent is placed at the `start_pos` and can either move `:left` or `:right`.
7-
The game terminates when the agent reaches either end and receives a reward
8-
correspondingly.
6+
An agent is placed at the `start_pos` and can move left or right (stride is
7+
defined in actions). The game terminates when the agent reaches either end and
8+
receives a reward correspondingly.
99
1010
Compared to the [`MultiArmBanditsEnv`](@ref):
1111
@@ -16,24 +16,15 @@ Compared to the [`MultiArmBanditsEnv`](@ref):
1616
Base.@kwdef mutable struct RandomWalk1D <: AbstractEnv
1717
rewards::Pair{Float64,Float64} = -1.0 => 1.0
1818
N::Int = 7
19+
actions::Vector{Int} = [-1, 1]
1920
start_pos::Int = (N + 1) ÷ 2
2021
pos::Int = start_pos
2122
end
2223

23-
const ACTIONS_OF_RANDOMWALK1D = (:left, :right)
24+
RLBase.action_space(env::RandomWalk1D) = Base.OneTo(length(env.actions))
2425

25-
RLBase.action_space(::RandomWalk1D) = Base.OneTo(length(ACTIONS_OF_RANDOMWALK1D))
26-
27-
(env::RandomWalk1D)(action::Int) = env(ACTIONS_OF_RANDOMWALK1D[action])
28-
29-
function (env::RandomWalk1D)(action::Symbol)
30-
if action == :left
31-
env.pos = max(env.pos - 1, 1)
32-
elseif action == :right
33-
env.pos = min(env.pos + 1, env.N)
34-
else
35-
@error "invalid action: $action"
36-
end
26+
function (env::RandomWalk1D)(action)
27+
env.pos = max(min(env.pos + env.actions[action], env.N), 1)
3728
end
3829

3930
RLBase.state(env::RandomWalk1D) = env.pos

src/environments/examples/TicTacToeEnv.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,26 @@ RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
8282
RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
8383
Base.OneTo(length(get_tic_tac_toe_state_info()))
8484

85+
RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, p) = WorldSpace{String}()
86+
87+
function RLBase.state(env::TicTacToeEnv, ::Observation{String}, p)
88+
buff = IOBuffer()
89+
for i in 1:3
90+
for j in 1:3
91+
if env.board[i, j, 1]
92+
x = '.'
93+
elseif env.board[i, j, 2]
94+
x = 'x'
95+
else
96+
x = 'o'
97+
end
98+
print(buff, x)
99+
end
100+
print(buff, '\n')
101+
end
102+
String(take!(buff))
103+
end
104+
85105
RLBase.is_terminated(env::TicTacToeEnv) = get_tic_tac_toe_state_info()[env].is_terminated
86106

87107
function RLBase.reward(env::TicTacToeEnv, player)
@@ -150,7 +170,7 @@ RLBase.NumAgentStyle(::TicTacToeEnv) = MultiAgent(2)
150170
RLBase.DynamicStyle(::TicTacToeEnv) = SEQUENTIAL
151171
RLBase.ActionStyle(::TicTacToeEnv) = FULL_ACTION_SET
152172
RLBase.InformationStyle(::TicTacToeEnv) = PERFECT_INFORMATION
153-
RLBase.StateStyle(::TicTacToeEnv) = (Observation{Int}(), Observation{BitArray{3}}())
173+
RLBase.StateStyle(::TicTacToeEnv) = (Observation{String}(), Observation{Int}(), Observation{BitArray{3}}())
154174
RLBase.RewardStyle(::TicTacToeEnv) = TERMINAL_REWARD
155175
RLBase.UtilityStyle(::TicTacToeEnv) = ZERO_SUM
156176
RLBase.ChanceStyle(::TicTacToeEnv) = DETERMINISTIC

src/environments/wrappers/DefaultStateStyle.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env)
1414
RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S
1515

1616
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
17-
if f (:DefaultStateStyle,)
17+
if f (:DefaultStateStyle, :state, :state_space)
1818
@eval RLBase.$f(x::DefaultStateStyleEnv, args...; kwargs...) =
1919
$f(x.env, args...; kwargs...)
2020
end
@@ -23,5 +23,10 @@ end
2323
(env::DefaultStateStyleEnv)(args...; kwargs...) = env.env(args...; kwargs...)
2424

2525
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
26+
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) = state(env.env, ss, p)
27+
2628
RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) =
2729
state_space(env.env, ss)
30+
31+
RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) =
32+
state_space(env.env, ss, p)

0 commit comments

Comments
 (0)