Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 6ed2018

Browse files
Format .jl files (#32)
1 parent 1d60e86 commit 6ed2018

File tree

9 files changed

+120
-123
lines changed

9 files changed

+120
-123
lines changed

src/environments/atari.jl

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ using ArcadeLearningEnvironment, GR, Random
22

33
export AtariEnv
44

5-
mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
5+
mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv
66
ale::Ptr{Nothing}
7-
screens::Tuple{Array{UInt8, N}, Array{UInt8, N}} # for max-pooling
7+
screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling
88
actions::Vector{Int64}
99
action_space::DiscreteSpace{UnitRange{Int}}
10-
observation_space::MultiDiscreteSpace{Array{UInt8, N}}
10+
observation_space::MultiDiscreteSpace{Array{UInt8,N}}
1111
noopmax::Int
1212
frame_skip::Int
1313
reward::Float32
@@ -36,24 +36,25 @@ TODO: support seed! in single/multi thread
3636
3737
See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
3838
"""
39-
function AtariEnv(
40-
;name = "pong",
41-
grayscale_obs=true,
39+
function AtariEnv(;
40+
name = "pong",
41+
grayscale_obs = true,
4242
noop_max = 30,
4343
frame_skip = 4,
44-
terminal_on_life_loss=false,
45-
repeat_action_probability=0.,
46-
color_averaging=false,
47-
max_num_frames_per_episode=0,
48-
full_action_space=false,
49-
seed=nothing
44+
terminal_on_life_loss = false,
45+
repeat_action_probability = 0.0,
46+
color_averaging = false,
47+
max_num_frames_per_episode = 0,
48+
full_action_space = false,
49+
seed = nothing,
5050
)
5151
frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
52-
name in getROMList() || throw(ArgumentError("unknown ROM name! run `getROMList()` to see all the game names."))
52+
name in getROMList() ||
53+
throw(ArgumentError("unknown ROM name! run `getROMList()` to see all the game names."))
5354

5455
if isnothing(seed)
5556
seed = (MersenneTwister(), 0)
56-
elseif seed isa Tuple{Int, Int}
57+
elseif seed isa Tuple{Int,Int}
5758
seed = (MersenneTwister(seed[1]), seed[2])
5859
else
5960
@error "You must specify two seeds, one for Julia wrapper, one for internal C implementation" # ??? maybe auto generate two seed from one
@@ -67,20 +68,19 @@ function AtariEnv(
6768
setBool(ale, "color_averaging", color_averaging)
6869
loadROM(ale, name)
6970

70-
observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
71+
observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) :
72+
(3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
7173
observation_space = MultiDiscreteSpace(
7274
fill(typemin(Cuchar), observation_size),
7375
fill(typemax(Cuchar), observation_size),
7476
)
7577

7678
actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
7779
action_space = DiscreteSpace(length(actions))
78-
screens = (
79-
fill(typemin(Cuchar), observation_size),
80-
fill(typemin(Cuchar), observation_size),
81-
)
80+
screens =
81+
(fill(typemin(Cuchar), observation_size), fill(typemin(Cuchar), observation_size))
8282

83-
AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
83+
AtariEnv{grayscale_obs,terminal_on_life_loss,grayscale_obs ? 2 : 3,typeof(seed[1])}(
8484
ale,
8585
screens,
8686
actions,
@@ -90,14 +90,18 @@ function AtariEnv(
9090
frame_skip,
9191
0.0f0,
9292
lives(ale),
93-
seed[1]
93+
seed[1],
9494
)
9595
end
9696

97-
update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
98-
update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
97+
update_screen!(env::AtariEnv{true}, screen) =
98+
ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
99+
update_screen!(env::AtariEnv{false}, screen) =
100+
ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
99101

100-
function (env::AtariEnv{is_gray_scale, is_terminal_on_life_loss})(a) where {is_gray_scale, is_terminal_on_life_loss}
102+
function (env::AtariEnv{is_gray_scale,is_terminal_on_life_loss})(
103+
a,
104+
) where {is_gray_scale,is_terminal_on_life_loss}
101105
r = 0.0f0
102106

103107
for i in 1:env.frame_skip
@@ -118,14 +122,15 @@ function (env::AtariEnv{is_gray_scale, is_terminal_on_life_loss})(a) where {is_g
118122
nothing
119123
end
120124

121-
is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
122-
is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)
125+
is_terminal(env::AtariEnv{<:Any,true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
126+
is_terminal(env::AtariEnv{<:Any,false}) = game_over(env.ale)
123127

124-
RLBase.observe(env::AtariEnv) = (reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
128+
RLBase.observe(env::AtariEnv) =
129+
(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
125130

126131
function RLBase.reset!(env::AtariEnv)
127132
reset_game(env.ale)
128-
for _ = 1:rand(env.seed, 0:env.noopmax)
133+
for _ in 1:rand(env.seed, 0:env.noopmax)
129134
act(env.ale, Int32(0))
130135
end
131136
update_screen!(env, env.screens[1]) # no need to update env.screens[2]
@@ -145,7 +150,7 @@ function imshowcolor(x::Array{UInt8,1}, dims)
145150
setwindow(0, 1, 0, 1)
146151
y = (zeros(UInt32, dims...) .+ 0xff) .<< 24
147152
img = UInt32.(x)
148-
@simd for i = 1:length(y)
153+
@simd for i in 1:length(y)
149154
@inbounds y[i] += img[3*(i-1)+1] + img[3*(i-1)+2] << 8 + img[3*i] << 16
150155
end
151156
drawimage(0, 1, 0, 1, dims..., y)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
include("cartpole.jl")
22
include("mountain_car.jl")
3-
include("pendulum.jl")
3+
include("pendulum.jl")

src/environments/classic_control/mountain_car.jl

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,16 @@ struct MountainCarEnvParams{T}
1313
max_steps::Int
1414
end
1515

16-
function MountainCarEnvParams(
17-
;
16+
function MountainCarEnvParams(;
1817
T = Float64,
1918
min_pos = -1.2,
20-
max_pos = .6,
21-
max_speed = .07,
22-
goal_pos = .5,
19+
max_pos = 0.6,
20+
max_speed = 0.07,
21+
goal_pos = 0.5,
2322
max_steps = 200,
24-
goal_velocity = .0,
25-
power = .001,
26-
gravity = .0025,
23+
goal_velocity = 0.0,
24+
power = 0.001,
25+
gravity = 0.0025,
2726
)
2827
MountainCarEnvParams{T}(
2928
min_pos,
@@ -48,21 +47,15 @@ mutable struct MountainCarEnv{A,T,R<:AbstractRNG} <: AbstractEnv
4847
rng::R
4948
end
5049

51-
function MountainCarEnv(
52-
;
53-
T = Float64,
54-
continuous = false,
55-
seed = nothing,
56-
kwargs...,
57-
)
50+
function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwargs...)
5851
if continuous
59-
params = MountainCarEnvParams(; goal_pos = .45, power = .0015, T = T, kwargs...)
52+
params = MountainCarEnvParams(; goal_pos = 0.45, power = 0.0015, T = T, kwargs...)
6053
else
6154
params = MountainCarEnvParams(; kwargs...)
6255
end
6356
env = MountainCarEnv(
6457
params,
65-
continuous ? ContinuousSpace(-T(1.), T(1.)) : DiscreteSpace(3),
58+
continuous ? ContinuousSpace(-T(1.0), T(1.0)) : DiscreteSpace(3),
6659
MultiContinuousSpace(
6760
[params.min_pos, -params.max_speed],
6861
[params.max_pos, params.max_speed],
@@ -81,15 +74,12 @@ ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwar
8174

8275
Random.seed!(env::MountainCarEnv, seed) = Random.seed!(env.rng, seed)
8376

84-
RLBase.observe(env::MountainCarEnv) = (
85-
reward = env.done ? 0. : -1.,
86-
terminal = env.done,
87-
state = env.state
88-
)
77+
RLBase.observe(env::MountainCarEnv) =
78+
(reward = env.done ? 0.0 : -1.0, terminal = env.done, state = env.state)
8979

9080
function RLBase.reset!(env::MountainCarEnv{A,T}) where {A,T}
91-
env.state[1] = .2 * rand(env.rng, T) - .6
92-
env.state[2] = 0.
81+
env.state[1] = 0.2 * rand(env.rng, T) - 0.6
82+
env.state[2] = 0.0
9383
env.done = false
9484
env.t = 0
9585
nothing
@@ -109,8 +99,9 @@ function _interact!(env::MountainCarEnv, force)
10999
if x == env.params.min_pos && v < 0
110100
v = 0
111101
end
112-
env.done = x >= env.params.goal_pos && v >= env.params.goal_velocity ||
113-
env.t >= env.params.max_steps
102+
env.done =
103+
x >= env.params.goal_pos && v >= env.params.goal_velocity ||
104+
env.t >= env.params.max_steps
114105
env.state[1] = x
115106
env.state[2] = v
116107
nothing
@@ -126,10 +117,10 @@ function render(env::MountainCarEnv)
126117
clearws()
127118
setviewport(0, 1, 0, 1)
128119
setwindow(
129-
env.params.min_pos - .1,
130-
env.params.max_pos + .2,
120+
env.params.min_pos - 0.1,
121+
env.params.max_pos + 0.2,
131122
-.1,
132-
height(env.params.max_pos) + .2,
123+
height(env.params.max_pos) + 0.2,
133124
)
134125
xs = LinRange(env.params.min_pos, env.params.max_pos, 100)
135126
ys = height.(xs)
@@ -138,13 +129,13 @@ function render(env::MountainCarEnv)
138129
θ = cos(3 * x)
139130
carwidth = 0.05
140131
carheight = carwidth / 2
141-
clearance = .2 * carheight
132+
clearance = 0.2 * carheight
142133
xs = [-carwidth / 2, -carwidth / 2, carwidth / 2, carwidth / 2]
143134
ys = [0, carheight, carheight, 0]
144135
ys .+= clearance
145136
xs, ys = rotate(xs, ys, θ)
146137
xs, ys = translate(xs, ys, [x, height(x)])
147138
fillarea(xs, ys)
148-
plotendofepisode(env.params.max_pos + .1, 0, d)
139+
plotendofepisode(env.params.max_pos + 0.1, 0, d)
149140
updatews()
150141
end

src/environments/classic_control/pendulum.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,21 @@ mutable struct PendulumEnv{T,R<:AbstractRNG} <: AbstractEnv
2323
reward::T
2424
end
2525

26-
function PendulumEnv(
27-
;
26+
function PendulumEnv(;
2827
T = Float64,
2928
max_speed = T(8),
3029
max_torque = T(2),
3130
g = T(10),
3231
m = T(1),
3332
l = T(1),
34-
dt = T(.05),
33+
dt = T(0.05),
3534
max_steps = 200,
36-
seed = nothing
35+
seed = nothing,
3736
)
3837
high = T.([1, 1, max_speed])
3938
env = PendulumEnv(
4039
PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps),
41-
ContinuousSpace(-2., 2.),
40+
ContinuousSpace(-2.0, 2.0),
4241
MultiContinuousSpace(-high, high),
4342
zeros(T, 2),
4443
false,
@@ -55,11 +54,8 @@ Random.seed!(env::PendulumEnv, seed) = Random.seed!(env.rng, seed)
5554
pendulum_observation(s) = [cos(s[1]), sin(s[1]), s[2]]
5655
angle_normalize(x) = ((x + pi) % (2 * pi)) - pi
5756

58-
RLBase.observe(env::PendulumEnv) = (
59-
reward = env.reward,
60-
state = pendulum_observation(env.state),
61-
terminal = env.done,
62-
)
57+
RLBase.observe(env::PendulumEnv) =
58+
(reward = env.reward, state = pendulum_observation(env.state), terminal = env.done)
6359

6460
function RLBase.reset!(env::PendulumEnv{T}) where {T}
6561
env.state[:] = 2 * rand(env.rng, T, 2) .- 1
@@ -73,15 +69,18 @@ function (env::PendulumEnv)(a)
7369
env.t += 1
7470
th, thdot = env.state
7571
a = clamp(a, -env.params.max_torque, env.params.max_torque)
76-
costs = angle_normalize(th)^2 + .1 * thdot^2 + .001 * a^2
77-
newthdot = thdot +
78-
(-3 * env.params.g / (2 * env.params.l) * sin(th + pi) +
79-
3 * a / (env.params.m * env.params.l^2)) * env.params.dt
72+
costs = angle_normalize(th)^2 + 0.1 * thdot^2 + 0.001 * a^2
73+
newthdot =
74+
thdot +
75+
(
76+
-3 * env.params.g / (2 * env.params.l) * sin(th + pi) +
77+
3 * a / (env.params.m * env.params.l^2)
78+
) * env.params.dt
8079
th += newthdot * env.params.dt
8180
newthdot = clamp(newthdot, -env.params.max_speed, env.params.max_speed)
8281
env.state[1] = th
8382
env.state[2] = newthdot
8483
env.done = env.t >= env.params.max_steps
8584
env.reward = -costs
8685
nothing
87-
end
86+
end

src/environments/gym.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,19 @@ function GymEnv(name::String)
1616
pyenv = gym.make(name)
1717
obs_space = convert(AbstractSpace, pyenv.observation_space)
1818
act_space = convert(AbstractSpace, pyenv.action_space)
19-
obs_type =
20-
if obs_space isa Union{MultiContinuousSpace,MultiDiscreteSpace}
21-
PyArray
22-
elseif obs_space isa ContinuousSpace
23-
Float64
24-
elseif obs_space isa DiscreteSpace
25-
Int
26-
elseif obs_space isa TupleSpace
27-
PyVector
28-
elseif obs_space isa DictSpace
29-
PyDict
30-
else
31-
error("don't know how to get the observation type from observation space of $obs_space")
32-
end
19+
obs_type = if obs_space isa Union{MultiContinuousSpace,MultiDiscreteSpace}
20+
PyArray
21+
elseif obs_space isa ContinuousSpace
22+
Float64
23+
elseif obs_space isa DiscreteSpace
24+
Int
25+
elseif obs_space isa TupleSpace
26+
PyVector
27+
elseif obs_space isa DictSpace
28+
PyDict
29+
else
30+
error("don't know how to get the observation type from observation space of $obs_space")
31+
end
3332
env = GymEnv{obs_type,typeof(act_space),typeof(obs_space)}(
3433
pyenv,
3534
obs_space,
@@ -57,7 +56,7 @@ function RLBase.observe(env::GymEnv{T}) where {T}
5756
else
5857
# env has just been reseted
5958
(
60-
reward = 0., # dummy
59+
reward = 0.0, # dummy
6160
terminal = false,
6261
state = convert(T, env.state),
6362
)
@@ -107,7 +106,8 @@ function list_gym_env_names(;
107106
"gym.envs.robotics",
108107
"gym.envs.toy_text",
109108
"gym.envs.unittest",
110-
])
109+
],
110+
)
111111
gym = pyimport("gym")
112112
[x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules]
113-
end
113+
end

0 commit comments

Comments
 (0)