This repository was archived by the owner on May 6, 2021. It is now read-only.

Update RLBase to v0.8 #76

Merged 11 commits on Aug 4, 2020
Changes from 4 commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -14,7 +14,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
GR = "0.46, 0.47, 0.48, 0.49, 0.50"
OrdinaryDiffEq = "5"
ReinforcementLearningBase = "0.7"
ReinforcementLearningBase = "0.8"
Requires = "1.0"
StatsBase = "0.32, 0.33"
julia = "1.3"
4 changes: 2 additions & 2 deletions README.md
@@ -33,7 +33,7 @@ By default, only some basic environments are installed. If you want to use some
| `AtariEnv` | [ArcadeLearningEnvironment.jl](https://github.com/JuliaReinforcementLearning/ArcadeLearningEnvironment.jl) | Tested only on Linux|
| `ViZDoomEnv` | [ViZDoom.jl](https://github.com/JuliaReinforcementLearning/ViZDoom.jl) | Currently only a basic environment is supported. (By calling `basic_ViZDoom_env()`)|
| `GymEnv` | [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) | Tested only on Linux |
| `MDPEnv`,`POMDPEnv`| [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl)| The `get_observation_space` method is undefined|
| `MDPEnv`,`POMDPEnv`| [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl)||
| `OpenSpielEnv` | [OpenSpiel.jl](https://github.com/JuliaReinforcementLearning/OpenSpiel.jl) | |

## Usage
@@ -46,7 +46,7 @@ julia> using ReinforcementLearningBase
julia> env = CartPoleEnv()
CartPoleEnv{Float64}(gravity=9.8,masscart=1.0,masspole=0.1,totalmass=1.1,halflength=0.5,polemasslength=0.05,forcemag=10.0,tau=0.02,thetathreshold=0.20943951023931953,xthreshold=2.4,max_steps=200)

julia> action_space = get_action_space(env)
julia> action_space = get_actions(env)
DiscreteSpace{UnitRange{Int64}}(1:2)

julia> while true
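The README snippet above is cut off at the start of its interaction loop. As an editor's illustration (the helper name `run_random_episode!` is hypothetical), a full random-action episode against the renamed v0.8 accessors used throughout this PR might look like:

```julia
using ReinforcementLearningBase
using ReinforcementLearningEnvironments

# Run one episode with uniformly random actions and return the accumulated reward.
function run_random_episode!(env)
    reset!(env)
    total_reward = 0.0
    while !get_terminal(env)
        env(rand(get_actions(env)))   # environments are called directly with an action
        total_reward += get_reward(env)
    end
    total_reward
end

env = CartPoleEnv()
run_random_episode!(env)
get_state(env)   # final observation of the episode
```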
37 changes: 20 additions & 17 deletions src/environments/atari.jl
@@ -6,8 +6,6 @@ using .ArcadeLearningEnvironment

This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009)

TODO: support seed! in single/multi thread

# Keywords

- `name::String="pong"`: name of the Atari environment. Use `ReinforcementLearningEnvironments.list_atari_rom_names()` to list all supported environments.
@@ -19,6 +17,8 @@ TODO: support seed! in single/multi thread
- `color_averaging::Bool=false`: whether to perform phosphor averaging or not.
- `max_num_frames_per_episode::Int=0`
- `full_action_space::Bool=false`: by default, only the minimal action set is used. If `true`, one needs to call `legal_actions` to get the valid action set. TODO
- `seed::Int` is used to set the initial seed of the underlying C environment.
- `rng::AbstractRNG` is used by this wrapper environment to draw the random number of no-op steps taken after [`reset!`](@ref).

See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
"""
@@ -34,21 +34,22 @@ function AtariEnv(;
max_num_frames_per_episode = 0,
full_action_space = false,
seed = nothing,
rng = Random.GLOBAL_RNG
)
frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
name in getROMList() ||
throw(ArgumentError("unknown ROM name.\n\nRun `ReinforcementLearningEnvironments.list_atari_rom_names()` to see all the game names."))

if isnothing(seed)
seed = (MersenneTwister(), 0)
elseif seed isa Tuple{Int,Int}
seed = (MersenneTwister(seed[1]), seed[2])
else
@error "You must specify two seeds, one for Julia wrapper, one for internal C implementation" # ??? maybe auto generate two seed from one
end

ale = ALE_new()
setInt(ale, "random_seed", seed[2])
if !isnothing(seed)
if rng === Random.GLOBAL_RNG
throw(ArgumentError("you set seed to $seed but the rng is not set"))
else
setInt(ale, "random_seed", seed)
end
elseif rng !== Random.GLOBAL_RNG
throw(ArgumentError("it seems that rng is set but seed is not set yet"))
end
setInt(ale, "frame_skip", Int32(1)) # !!! do not use internal frame_skip here, we need to apply max-pooling for the latest two frames, so we need to manually implement the mechanism.
setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
@@ -68,7 +69,7 @@ function AtariEnv(;
(fill(typemin(Cuchar), observation_size), fill(typemin(Cuchar), observation_size))

env =
AtariEnv{grayscale_obs,terminal_on_life_loss,grayscale_obs ? 2 : 3,typeof(seed[1])}(
AtariEnv{grayscale_obs,terminal_on_life_loss,grayscale_obs ? 2 : 3,typeof(rng)}(
ale,
screens,
actions,
@@ -78,7 +79,7 @@ function AtariEnv(;
frame_skip,
0.0f0,
lives(ale),
seed[1],
rng,
)
finalizer(env) do x
ALE_del(x.ale)
@@ -117,12 +118,14 @@ end
is_terminal(env::AtariEnv{<:Any,true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
is_terminal(env::AtariEnv{<:Any,false}) = game_over(env.ale)

RLBase.observe(env::AtariEnv) =
(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
RLBase.get_actions(env::AtariEnv) = env.action_space
RLBase.get_reward(env::AtariEnv) = env.reward
RLBase.get_terminal(env::AtariEnv) = is_terminal(env)
RLBase.get_state(env::AtariEnv) = env.screens[1]

function RLBase.reset!(env::AtariEnv)
reset_game(env.ale)
for _ in 1:rand(env.seed, 0:env.noopmax)
for _ in 1:rand(env.rng, 0:env.noopmax)
act(env.ale, Int32(0))
end
update_screen!(env, env.screens[1]) # no need to update env.screens[2]
@@ -148,7 +151,7 @@ function imshowcolor(x::AbstractArray{UInt8,1}, dims)
updatews()
end

function RLBase.render(env::AtariEnv)
function render(env::AtariEnv)
x = getScreenRGB(env.ale)
imshowcolor(x, (Int(getScreenWidth(env.ale)), Int(getScreenHeight(env.ale))))
end
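As shown in the `AtariEnv` constructor above, `seed` now only seeds the underlying ALE C environment, while `rng` drives the random number of no-op steps taken at `reset!`, and passing `seed` while leaving `rng` at `Random.GLOBAL_RNG` raises an `ArgumentError`. An illustrative sketch (ROM name and seed values are arbitrary, not part of this diff):

```julia
using Random
using ReinforcementLearningEnvironments

# Supply both keywords together, per the validation logic in the constructor above.
env = AtariEnv(;
    name = "pong",
    seed = 123,                  # seeds the underlying ALE C environment
    rng = MersenneTwister(456),  # used for the random no-op steps in reset!
)
reset!(env)
```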
10 changes: 6 additions & 4 deletions src/environments/classic_control/acrobot.jl
@@ -78,7 +78,7 @@ function AcrobotEnv(;
g = T(9.8),
dt = T(0.2),
max_steps = 200,
seed = nothing,
rng = Random.GLOBAL_RNG,
book_or_nips = "book",
avail_torque = [T(-1.0), T(0.0), T(1.0)],
)
@@ -108,7 +108,7 @@ function AcrobotEnv(;
0,
false,
0,
MersenneTwister(seed),
rng,
T(0.0),
book_or_nips,
[T(-1.0), T(0.0), T(1.0)],
@@ -119,8 +119,10 @@ end

acrobot_observation(s) = [cos(s[1]), sin(s[1]), cos(s[2]), sin(s[2]), s[3], s[4]]

RLBase.observe(env::AcrobotEnv) =
(reward = env.reward, state = acrobot_observation(env.state), terminal = env.done)
RLBase.get_actions(env::AcrobotEnv) = env.action_space
RLBase.get_terminal(env::AcrobotEnv) = env.done
RLBase.get_state(env::AcrobotEnv) = acrobot_observation(env.state)
RLBase.get_reward(env::AcrobotEnv) = env.reward

function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}
env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05)
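The same substitution, replacing the single `RLBase.observe` method with `get_actions`, `get_reward`, `get_terminal`, and `get_state`, repeats for every environment touched by this PR. Downstream code that still consumes the old observation named tuple could rebuild it from the new getters; a hypothetical compatibility helper (not part of this PR):

```julia
# Hypothetical shim: reconstruct the named tuple that RLBase.observe(env)
# used to return, using only the v0.8 getters introduced in this PR.
legacy_observe(env) = (
    reward = get_reward(env),
    terminal = get_terminal(env),
    state = get_state(env),
)

# Usage sketch:
# obs = legacy_observe(env)
# obs.reward, obs.terminal, obs.state
```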
14 changes: 8 additions & 6 deletions src/environments/classic_control/cartpole.jl
@@ -45,7 +45,7 @@ Base.show(io::IO, env::CartPoleEnv{T}) where {T} =
- `forcemag = T(10.0)`
- `max_steps = 200`
- `dt = 0.02`
- `seed = nothing`
- `rng = Random.GLOBAL_RNG`
"""
function CartPoleEnv(;
T = Float64,
@@ -56,7 +56,7 @@ function CartPoleEnv(;
forcemag = 10.0,
max_steps = 200,
dt = 0.02,
seed = nothing,
rng = Random.GLOBAL_RNG,
)
params = CartPoleEnvParams{T}(
gravity,
@@ -80,7 +80,7 @@ function CartPoleEnv(;
2,
false,
0,
MersenneTwister(seed),
rng,
)
reset!(cp)
cp
@@ -96,8 +96,10 @@ function RLBase.reset!(env::CartPoleEnv{T}) where {T<:Number}
nothing
end

RLBase.observe(env::CartPoleEnv{T}) where {T} =
(reward = env.done ? zero(T) : one(T), terminal = env.done, state = env.state)
RLBase.get_actions(env::CartPoleEnv) = env.action_space
RLBase.get_reward(env::CartPoleEnv{T}) where {T} = env.done ? zero(T) : one(T)
RLBase.get_terminal(env::CartPoleEnv) = env.done
RLBase.get_state(env::CartPoleEnv) = env.state

function (env::CartPoleEnv)(a)
@assert a in (1, 2)
@@ -136,7 +138,7 @@ function plotendofepisode(x, y, d)
end
return nothing
end
function RLBase.render(env::CartPoleEnv)
function render(env::CartPoleEnv)

Member Author commented: better to implement `Base.display`

s, a, d = env.state, env.action, env.done
x, xdot, theta, thetadot = s
l = 2 * env.params.halflength
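One hedged reading of the reviewer's note above ("better to implement Base.display"): rather than exposing a plain `render` function, the GR-based drawing could be hooked into Julia's display machinery. A hypothetical one-liner, not part of this PR:

```julia
# Hypothetical sketch of the reviewer's suggestion: make `display(env)`
# (and hence the REPL) draw the current frame via the existing render code.
Base.display(env::CartPoleEnv) = render(env)
```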
14 changes: 8 additions & 6 deletions src/environments/classic_control/mountain_car.jl
@@ -52,7 +52,7 @@ end

- `T = Float64`
- `continuous = false`
- `seed = nothing`
- `rng = Random.GLOBAL_RNG`
- `min_pos = -1.2`
- `max_pos = 0.6`
- `max_speed = 0.07`
@@ -62,7 +62,7 @@ end
- `power = 0.001`
- `gravity = 0.0025`
"""
function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwargs...)
function MountainCarEnv(; T = Float64, continuous = false, rng = Random.GLOBAL_RNG, kwargs...)
if continuous
params = MountainCarEnvParams(; goal_pos = 0.45, power = 0.0015, T = T, kwargs...)
else
@@ -80,7 +80,7 @@ function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwarg
rand(action_space),
false,
0,
MersenneTwister(seed),
rng,
)
reset!(env)
env
@@ -90,8 +90,10 @@ ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwar

Random.seed!(env::MountainCarEnv, seed) = Random.seed!(env.rng, seed)

RLBase.observe(env::MountainCarEnv) =
(reward = env.done ? 0.0 : -1.0, terminal = env.done, state = env.state)
RLBase.get_actions(env::MountainCarEnv) = env.action_space
RLBase.get_reward(env::MountainCarEnv{A,T}) where {A, T} = env.done ? zero(T) : -one(T)
RLBase.get_terminal(env::MountainCarEnv) = env.done
RLBase.get_state(env::MountainCarEnv) = env.state

function RLBase.reset!(env::MountainCarEnv{A,T}) where {A,T}
env.state[1] = 0.2 * rand(env.rng, T) - 0.6
@@ -135,7 +137,7 @@ end
height(xs) = sin(3 * xs) * 0.45 + 0.55
rotate(xs, ys, θ) = xs * cos(θ) - ys * sin(θ), ys * cos(θ) + xs * sin(θ)
translate(xs, ys, t) = xs .+ t[1], ys .+ t[2]
function RLBase.render(env::MountainCarEnv)
function render(env::MountainCarEnv)
s = env.state
d = env.done
clearws()
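Across the classic-control constructors the `seed` keyword is replaced by an `rng` keyword, while methods such as `Random.seed!(env, seed)` above still forward to `env.rng`. A brief sketch of both ways to get a reproducible environment (seed values are illustrative):

```julia
using Random
using ReinforcementLearningEnvironments

# Option 1: pass an explicit RNG in place of the removed `seed` keyword.
env1 = MountainCarEnv(; rng = MersenneTwister(2020))

# Option 2: reseed after construction; this forwards to Random.seed!(env2.rng, 2020).
env2 = MountainCarEnv()
Random.seed!(env2, 2020)
```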
12 changes: 7 additions & 5 deletions src/environments/classic_control/pendulum.jl
@@ -38,7 +38,7 @@ end
- `max_steps = 200`
- `continuous::Bool = true`
- `n_actions::Int = 3`
- `seed = nothing`
- `rng = Random.GLOBAL_RNG`
"""
function PendulumEnv(;
T = Float64,
@@ -51,7 +51,7 @@ function PendulumEnv(;
max_steps = 200,
continuous::Bool = true,
n_actions::Int = 3,
seed = nothing,
rng = Random.GLOBAL_RNG
)
high = T.([1, 1, max_speed])
action_space = continuous ? ContinuousSpace(-2.0, 2.0) : DiscreteSpace(n_actions)
@@ -62,7 +62,7 @@ function PendulumEnv(;
zeros(T, 2),
false,
0,
MersenneTwister(seed),
rng,
zero(T),
n_actions,
rand(action_space),
@@ -76,8 +76,10 @@ Random.seed!(env::PendulumEnv, seed) = Random.seed!(env.rng, seed)
pendulum_observation(s) = [cos(s[1]), sin(s[1]), s[2]]
angle_normalize(x) = Base.mod((x + Base.π), (2 * Base.π)) - Base.π

RLBase.observe(env::PendulumEnv) =
(reward = env.reward, state = pendulum_observation(env.state), terminal = env.done)
RLBase.get_actions(env::PendulumEnv) = env.action_space
RLBase.get_reward(env::PendulumEnv) = env.reward
RLBase.get_terminal(env::PendulumEnv) = env.done
RLBase.get_state(env::PendulumEnv) = pendulum_observation(env.state)

function RLBase.reset!(env::PendulumEnv{A,T}) where {A,T}
env.state[1] = 2 * π * (rand(env.rng, T) .- 1)
33 changes: 24 additions & 9 deletions src/environments/gym.jl
@@ -46,21 +46,36 @@ function RLBase.reset!(env::GymEnv)
nothing
end

function RLBase.observe(env::GymEnv{T}) where {T}
RLBase.get_actions(env::GymEnv) = env.action_space

function RLBase.get_reward(env::GymEnv{T}) where {T}
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
(reward = reward, terminal = isdone, state = obs)
reward
else
# env has just been reset
(
reward = 0.0, # dummy
terminal = false,
state = convert(T, env.state),
)
0.0
end
end

function RLBase.get_terminal(env::GymEnv{T}) where {T}
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
isdone
else
false
end
end

function RLBase.get_state(env::GymEnv{T}) where {T}
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
obs
else
state = convert(T, env.state)
end
end

RLBase.render(env::GymEnv) = env.pyenv.render()
render(env::GymEnv) = env.pyenv.render()
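The three getters above all branch on whether `env.state` still holds the raw observation (right after `reset!`) or the `(obs, reward, done, info)` step tuple. An interaction sketch, assuming a working PyCall/gym installation and that `GymEnv` is constructed by name and stepped by calling it with an action, as the other environments in this PR are:

```julia
using ReinforcementLearningBase
using ReinforcementLearningEnvironments

env = GymEnv("CartPole-v0")   # assumes the Python `gym` package is importable via PyCall
reset!(env)
get_terminal(env)             # false: env.state is just the observation after reset!
get_reward(env)               # 0.0 dummy reward for the same reason

env(rand(get_actions(env)))   # take one random step
get_reward(env), get_terminal(env), get_state(env)   # now read from the step tuple
```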

###
### utils