Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Automatic JuliaFormatter.jl run #32

Merged
merged 1 commit into from
Feb 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 34 additions & 29 deletions src/environments/atari.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ using ArcadeLearningEnvironment, GR, Random

export AtariEnv

mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv
ale::Ptr{Nothing}
screens::Tuple{Array{UInt8, N}, Array{UInt8, N}} # for max-pooling
screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling
actions::Vector{Int64}
action_space::DiscreteSpace{UnitRange{Int}}
observation_space::MultiDiscreteSpace{Array{UInt8, N}}
observation_space::MultiDiscreteSpace{Array{UInt8,N}}
noopmax::Int
frame_skip::Int
reward::Float32
Expand Down Expand Up @@ -36,24 +36,25 @@ TODO: support seed! in single/multi thread

See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
"""
function AtariEnv(
;name = "pong",
grayscale_obs=true,
function AtariEnv(;
name = "pong",
grayscale_obs = true,
noop_max = 30,
frame_skip = 4,
terminal_on_life_loss=false,
repeat_action_probability=0.,
color_averaging=false,
max_num_frames_per_episode=0,
full_action_space=false,
seed=nothing
terminal_on_life_loss = false,
repeat_action_probability = 0.0,
color_averaging = false,
max_num_frames_per_episode = 0,
full_action_space = false,
seed = nothing,
)
frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
name in getROMList() || throw(ArgumentError("unknown ROM name! run `getROMList()` to see all the game names."))
name in getROMList() ||
throw(ArgumentError("unknown ROM name! run `getROMList()` to see all the game names."))

if isnothing(seed)
seed = (MersenneTwister(), 0)
elseif seed isa Tuple{Int, Int}
elseif seed isa Tuple{Int,Int}
seed = (MersenneTwister(seed[1]), seed[2])
else
@error "You must specify two seeds, one for Julia wrapper, one for internal C implementation" # ??? maybe auto generate two seed from one
Expand All @@ -67,20 +68,19 @@ function AtariEnv(
setBool(ale, "color_averaging", color_averaging)
loadROM(ale, name)

observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) :
(3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
observation_space = MultiDiscreteSpace(
fill(typemin(Cuchar), observation_size),
fill(typemax(Cuchar), observation_size),
)

actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
action_space = DiscreteSpace(length(actions))
screens = (
fill(typemin(Cuchar), observation_size),
fill(typemin(Cuchar), observation_size),
)
screens =
(fill(typemin(Cuchar), observation_size), fill(typemin(Cuchar), observation_size))

AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
AtariEnv{grayscale_obs,terminal_on_life_loss,grayscale_obs ? 2 : 3,typeof(seed[1])}(
ale,
screens,
actions,
Expand All @@ -90,14 +90,18 @@ function AtariEnv(
frame_skip,
0.0f0,
lives(ale),
seed[1]
seed[1],
)
end

update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
# Fetch the current screen from ALE, handing it `vec(screen)` as the buffer argument.
# Dispatch is on the first type parameter (`IsGrayScale`): `true` uses the grayscale
# getter, `false` uses the RGB getter.
function update_screen!(env::AtariEnv{true}, screen)
    return ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
end

function update_screen!(env::AtariEnv{false}, screen)
    return ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
end

function (env::AtariEnv{is_gray_scale, is_terminal_on_life_loss})(a) where {is_gray_scale, is_terminal_on_life_loss}
function (env::AtariEnv{is_gray_scale,is_terminal_on_life_loss})(
a,
) where {is_gray_scale,is_terminal_on_life_loss}
r = 0.0f0

for i in 1:env.frame_skip
Expand All @@ -118,14 +122,15 @@ function (env::AtariEnv{is_gray_scale, is_terminal_on_life_loss})(a) where {is_g
nothing
end

is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)
# `TerminalOnLifeLoss == true`: terminal when ALE reports game over, or when the
# emulator's current life count is below the count stored in `env.lives`.
function is_terminal(env::AtariEnv{<:Any,true})
    return game_over(env.ale) || lives(env.ale) < env.lives
end

# `TerminalOnLifeLoss == false`: terminal only when ALE reports game over.
function is_terminal(env::AtariEnv{<:Any,false})
    return game_over(env.ale)
end

RLBase.observe(env::AtariEnv) = (reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
# Build the observation named tuple: the stored reward, the terminal flag from
# `is_terminal`, and the first of the two screen buffers as the state.
function RLBase.observe(env::AtariEnv)
    return (reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
end

function RLBase.reset!(env::AtariEnv)
reset_game(env.ale)
for _ = 1:rand(env.seed, 0:env.noopmax)
for _ in 1:rand(env.seed, 0:env.noopmax)
act(env.ale, Int32(0))
end
update_screen!(env, env.screens[1]) # no need to update env.screens[2]
Expand All @@ -145,7 +150,7 @@ function imshowcolor(x::Array{UInt8,1}, dims)
setwindow(0, 1, 0, 1)
y = (zeros(UInt32, dims...) .+ 0xff) .<< 24
img = UInt32.(x)
@simd for i = 1:length(y)
@simd for i in 1:length(y)
@inbounds y[i] += img[3*(i-1)+1] + img[3*(i-1)+2] << 8 + img[3*i] << 16
end
drawimage(0, 1, 0, 1, dims..., y)
Expand Down
2 changes: 1 addition & 1 deletion src/environments/classic_control/classic_control.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
include("cartpole.jl")
include("mountain_car.jl")
include("pendulum.jl")
include("pendulum.jl")
53 changes: 22 additions & 31 deletions src/environments/classic_control/mountain_car.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ struct MountainCarEnvParams{T}
max_steps::Int
end

function MountainCarEnvParams(
;
function MountainCarEnvParams(;
T = Float64,
min_pos = -1.2,
max_pos = .6,
max_speed = .07,
goal_pos = .5,
max_pos = 0.6,
max_speed = 0.07,
goal_pos = 0.5,
max_steps = 200,
goal_velocity = .0,
power = .001,
gravity = .0025,
goal_velocity = 0.0,
power = 0.001,
gravity = 0.0025,
)
MountainCarEnvParams{T}(
min_pos,
Expand All @@ -48,21 +47,15 @@ mutable struct MountainCarEnv{A,T,R<:AbstractRNG} <: AbstractEnv
rng::R
end

function MountainCarEnv(
;
T = Float64,
continuous = false,
seed = nothing,
kwargs...,
)
function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwargs...)
if continuous
params = MountainCarEnvParams(; goal_pos = .45, power = .0015, T = T, kwargs...)
params = MountainCarEnvParams(; goal_pos = 0.45, power = 0.0015, T = T, kwargs...)
else
params = MountainCarEnvParams(; kwargs...)
end
env = MountainCarEnv(
params,
continuous ? ContinuousSpace(-T(1.), T(1.)) : DiscreteSpace(3),
continuous ? ContinuousSpace(-T(1.0), T(1.0)) : DiscreteSpace(3),
MultiContinuousSpace(
[params.min_pos, -params.max_speed],
[params.max_pos, params.max_speed],
Expand All @@ -81,15 +74,12 @@ ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwar

# Reseed the environment's own RNG (`env.rng`) by forwarding to `Random.seed!`.
function Random.seed!(env::MountainCarEnv, seed)
    return Random.seed!(env.rng, seed)
end

RLBase.observe(env::MountainCarEnv) = (
reward = env.done ? 0. : -1.,
terminal = env.done,
state = env.state
)
# Build the observation named tuple. The reward is 0.0 once `env.done` is set and
# -1.0 otherwise; `terminal` and `state` are read straight from the environment.
function RLBase.observe(env::MountainCarEnv)
    reward = env.done ? 0.0 : -1.0
    return (reward = reward, terminal = env.done, state = env.state)
end

function RLBase.reset!(env::MountainCarEnv{A,T}) where {A,T}
env.state[1] = .2 * rand(env.rng, T) - .6
env.state[2] = 0.
env.state[1] = 0.2 * rand(env.rng, T) - 0.6
env.state[2] = 0.0
env.done = false
env.t = 0
nothing
Expand All @@ -109,8 +99,9 @@ function _interact!(env::MountainCarEnv, force)
if x == env.params.min_pos && v < 0
v = 0
end
env.done = x >= env.params.goal_pos && v >= env.params.goal_velocity ||
env.t >= env.params.max_steps
env.done =
x >= env.params.goal_pos && v >= env.params.goal_velocity ||
env.t >= env.params.max_steps
env.state[1] = x
env.state[2] = v
nothing
Expand All @@ -126,10 +117,10 @@ function render(env::MountainCarEnv)
clearws()
setviewport(0, 1, 0, 1)
setwindow(
env.params.min_pos - .1,
env.params.max_pos + .2,
env.params.min_pos - 0.1,
env.params.max_pos + 0.2,
-.1,
height(env.params.max_pos) + .2,
height(env.params.max_pos) + 0.2,
)
xs = LinRange(env.params.min_pos, env.params.max_pos, 100)
ys = height.(xs)
Expand All @@ -138,13 +129,13 @@ function render(env::MountainCarEnv)
θ = cos(3 * x)
carwidth = 0.05
carheight = carwidth / 2
clearance = .2 * carheight
clearance = 0.2 * carheight
xs = [-carwidth / 2, -carwidth / 2, carwidth / 2, carwidth / 2]
ys = [0, carheight, carheight, 0]
ys .+= clearance
xs, ys = rotate(xs, ys, θ)
xs, ys = translate(xs, ys, [x, height(x)])
fillarea(xs, ys)
plotendofepisode(env.params.max_pos + .1, 0, d)
plotendofepisode(env.params.max_pos + 0.1, 0, d)
updatews()
end
29 changes: 14 additions & 15 deletions src/environments/classic_control/pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,21 @@ mutable struct PendulumEnv{T,R<:AbstractRNG} <: AbstractEnv
reward::T
end

function PendulumEnv(
;
function PendulumEnv(;
T = Float64,
max_speed = T(8),
max_torque = T(2),
g = T(10),
m = T(1),
l = T(1),
dt = T(.05),
dt = T(0.05),
max_steps = 200,
seed = nothing
seed = nothing,
)
high = T.([1, 1, max_speed])
env = PendulumEnv(
PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps),
ContinuousSpace(-2., 2.),
ContinuousSpace(-2.0, 2.0),
MultiContinuousSpace(-high, high),
zeros(T, 2),
false,
Expand All @@ -55,11 +54,8 @@ Random.seed!(env::PendulumEnv, seed) = Random.seed!(env.rng, seed)
# Map a state `s` to the 3-element observation `[cos(s[1]), sin(s[1]), s[2]]`.
function pendulum_observation(s)
    th = s[1]
    return [cos(th), sin(th), s[2]]
end
"""
    angle_normalize(x)

Wrap the angle `x` (in radians) into the interval `[-π, π)`, up to floating-point rounding.

Uses `mod` (floored modulo; the result takes the sign of the divisor) instead of `%`/`rem`
(truncated; the result takes the sign of the dividend). With `%`, any `x < -π` would come
back unwrapped, outside the target interval. For `x >= -π` the two forms agree exactly.
"""
angle_normalize(x) = mod(x + pi, 2 * pi) - pi

RLBase.observe(env::PendulumEnv) = (
reward = env.reward,
state = pendulum_observation(env.state),
terminal = env.done,
)
# Build the observation named tuple: the stored reward, the state converted by
# `pendulum_observation`, and the `done` flag as `terminal`.
function RLBase.observe(env::PendulumEnv)
    obs = pendulum_observation(env.state)
    return (reward = env.reward, state = obs, terminal = env.done)
end

function RLBase.reset!(env::PendulumEnv{T}) where {T}
env.state[:] = 2 * rand(env.rng, T, 2) .- 1
Expand All @@ -73,15 +69,18 @@ function (env::PendulumEnv)(a)
env.t += 1
th, thdot = env.state
a = clamp(a, -env.params.max_torque, env.params.max_torque)
costs = angle_normalize(th)^2 + .1 * thdot^2 + .001 * a^2
newthdot = thdot +
(-3 * env.params.g / (2 * env.params.l) * sin(th + pi) +
3 * a / (env.params.m * env.params.l^2)) * env.params.dt
costs = angle_normalize(th)^2 + 0.1 * thdot^2 + 0.001 * a^2
newthdot =
thdot +
(
-3 * env.params.g / (2 * env.params.l) * sin(th + pi) +
3 * a / (env.params.m * env.params.l^2)
) * env.params.dt
th += newthdot * env.params.dt
newthdot = clamp(newthdot, -env.params.max_speed, env.params.max_speed)
env.state[1] = th
env.state[2] = newthdot
env.done = env.t >= env.params.max_steps
env.reward = -costs
nothing
end
end
34 changes: 17 additions & 17 deletions src/environments/gym.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,19 @@ function GymEnv(name::String)
pyenv = gym.make(name)
obs_space = convert(AbstractSpace, pyenv.observation_space)
act_space = convert(AbstractSpace, pyenv.action_space)
obs_type =
if obs_space isa Union{MultiContinuousSpace,MultiDiscreteSpace}
PyArray
elseif obs_space isa ContinuousSpace
Float64
elseif obs_space isa DiscreteSpace
Int
elseif obs_space isa TupleSpace
PyVector
elseif obs_space isa DictSpace
PyDict
else
error("don't know how to get the observation type from observation space of $obs_space")
end
obs_type = if obs_space isa Union{MultiContinuousSpace,MultiDiscreteSpace}
PyArray
elseif obs_space isa ContinuousSpace
Float64
elseif obs_space isa DiscreteSpace
Int
elseif obs_space isa TupleSpace
PyVector
elseif obs_space isa DictSpace
PyDict
else
error("don't know how to get the observation type from observation space of $obs_space")
end
env = GymEnv{obs_type,typeof(act_space),typeof(obs_space)}(
pyenv,
obs_space,
Expand Down Expand Up @@ -57,7 +56,7 @@ function RLBase.observe(env::GymEnv{T}) where {T}
else
# env has just been reset
(
reward = 0., # dummy
reward = 0.0, # dummy
terminal = false,
state = convert(T, env.state),
)
Expand Down Expand Up @@ -107,7 +106,8 @@ function list_gym_env_names(;
"gym.envs.robotics",
"gym.envs.toy_text",
"gym.envs.unittest",
])
],
)
gym = pyimport("gym")
[x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules]
end
end
Loading