Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Support rlintro #119

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions src/environments/examples/RandomWalk1D.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
export RandomWalk1D

"""
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2)
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])

An agent is placed at the `start_pos` and can either move `:left` or `:right`.
The game terminates when the agent reaches either end and receives a reward
correspondingly.
An agent is placed at `start_pos` and moves left or right; the stride of each
move is the entry of `actions` selected by the chosen action. The game terminates
when the agent reaches either end, and the agent receives the corresponding reward.

Compared to the [`MultiArmBanditsEnv`](@ref):

Expand All @@ -16,24 +16,15 @@ Compared to the [`MultiArmBanditsEnv`](@ref):
Base.@kwdef mutable struct RandomWalk1D <: AbstractEnv
# Pair of rewards associated with the two ends of the walk.
# NOTE(review): presumably `first` belongs to the left end and `last` to the right end — confirm against `reward`.
rewards::Pair{Float64,Float64} = -1.0 => 1.0
# Number of positions; `pos` is kept within `1:N` whenever an action is applied.
N::Int = 7
# Offset added to `pos` for each action index (by default `-1` and `1`).
actions::Vector{Int} = [-1, 1]
# Initial position; defaults to `(N + 1) ÷ 2`, the midpoint when `N` is odd.
start_pos::Int = (N + 1) ÷ 2
# Current position of the agent; initialised to `start_pos`.
pos::Int = start_pos
end

# Symbolic names of the two moves; integer actions are translated by indexing into this tuple.
const ACTIONS_OF_RANDOMWALK1D = (:left, :right)
# The action space holds one integer index per entry of `env.actions`.
function RLBase.action_space(env::RandomWalk1D)
    n_actions = length(env.actions)
    return Base.OneTo(n_actions)
end

# The action space holds one integer index per entry of `ACTIONS_OF_RANDOMWALK1D`.
function RLBase.action_space(::RandomWalk1D)
    n_actions = length(ACTIONS_OF_RANDOMWALK1D)
    return Base.OneTo(n_actions)
end

# Integer actions are mapped to their symbolic name and re-dispatched on the environment.
function (env::RandomWalk1D)(action::Int)
    name = ACTIONS_OF_RANDOMWALK1D[action]
    return env(name)
end

function (env::RandomWalk1D)(action::Symbol)
if action == :left
env.pos = max(env.pos - 1, 1)
elseif action == :right
env.pos = min(env.pos + 1, env.N)
else
@error "invalid action: $action"
end
# Apply `action`: shift the position by `env.actions[action]`, then keep it
# inside the walk by capping at `env.N` from above and at `1` from below.
# The assignment is the last expression, so the new position is returned.
function (env::RandomWalk1D)(action)
    target = env.pos + env.actions[action]
    capped = min(target, env.N)
    env.pos = max(capped, 1)
end

# The state of the environment is the agent's current position.
function RLBase.state(env::RandomWalk1D)
    return env.pos
end
Expand Down
22 changes: 21 additions & 1 deletion src/environments/examples/TicTacToeEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
# Integer observations range over one index per entry of the table returned
# by `get_tic_tac_toe_state_info()`.
function RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p)
    info = get_tic_tac_toe_state_info()
    return Base.OneTo(length(info))
end

# String observations are described by a `WorldSpace{String}`.
function RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, p)
    return WorldSpace{String}()
end

# Render the 3×3 board as text: three lines of three characters, each line
# terminated by a newline.
function RLBase.state(env::TicTacToeEnv, ::Observation{String}, p)
    # Character for one cell: layer 1 set → '.', else layer 2 set → 'x', else 'o'.
    cell(r, c) = env.board[r, c, 1] ? '.' : (env.board[r, c, 2] ? 'x' : 'o')
    rows = [join(cell(r, c) for c in 1:3) for r in 1:3]
    # Every row, including the last one, is followed by '\n'.
    return join(rows, '\n') * "\n"
end

# Termination is looked up in the table returned by
# `get_tic_tac_toe_state_info()`, keyed by the environment itself.
function RLBase.is_terminated(env::TicTacToeEnv)
    info = get_tic_tac_toe_state_info()[env]
    return info.is_terminated
end

function RLBase.reward(env::TicTacToeEnv, player)
Expand Down Expand Up @@ -150,7 +170,7 @@ RLBase.NumAgentStyle(::TicTacToeEnv) = MultiAgent(2)
# Trait declarations for `TicTacToeEnv`: each line fixes the value that one
# RLBase trait function returns for this environment type.
RLBase.DynamicStyle(::TicTacToeEnv) = SEQUENTIAL
RLBase.ActionStyle(::TicTacToeEnv) = FULL_ACTION_SET
RLBase.InformationStyle(::TicTacToeEnv) = PERFECT_INFORMATION
# NOTE(review): the two `StateStyle` lines below share one method signature;
# the second additionally lists `Observation{String}()` as the first entry of
# the tuple. Only one of them should remain.
RLBase.StateStyle(::TicTacToeEnv) = (Observation{Int}(), Observation{BitArray{3}}())
RLBase.StateStyle(::TicTacToeEnv) = (Observation{String}(), Observation{Int}(), Observation{BitArray{3}}())
RLBase.RewardStyle(::TicTacToeEnv) = TERMINAL_REWARD
RLBase.UtilityStyle(::TicTacToeEnv) = ZERO_SUM
RLBase.ChanceStyle(::TicTacToeEnv) = DETERMINISTIC
7 changes: 6 additions & 1 deletion src/environments/wrappers/DefaultStateStyle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env)
# The default state style of the wrapper is its type parameter `S`.
function RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S}
    return S
end

for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
if f ∉ (:DefaultStateStyle,)
if f ∉ (:DefaultStateStyle, :state, :state_space)
@eval RLBase.$f(x::DefaultStateStyleEnv, args...; kwargs...) =
$f(x.env, args...; kwargs...)
end
Expand All @@ -23,5 +23,10 @@ end
# Calling the wrapper forwards all positional and keyword arguments to the
# wrapped environment.
function (env::DefaultStateStyleEnv)(args...; kwargs...)
    inner = env.env
    return inner(args...; kwargs...)
end

# When a state style is given explicitly, the query is forwarded unchanged to
# the wrapped environment.
function RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle)
    inner = env.env
    return state(inner, ss)
end

# Same forwarding, with the extra trailing argument `p` passed through.
function RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p)
    inner = env.env
    return state(inner, ss, p)
end

# When a state style is given explicitly, the state-space query is forwarded
# unchanged to the wrapped environment.
function RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle)
    inner = env.env
    return state_space(inner, ss)
end

# Same forwarding, with the extra trailing argument `p` passed through.
function RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p)
    inner = env.env
    return state_space(inner, ss, p)
end