Commit 8c21b67: update atari env to use max-pooling (#22)

* update atari env
* add test cases for atari environments
* add doc

1 parent: 033cf4c

5 files changed: +190, -57 lines

src/abstractenv.jl

Lines changed: 15 additions & 0 deletions
@@ -20,6 +20,21 @@ function action_space end
 function observation_space end
 function render end
 
+"""
+    Observation(;reward, terminal, state, meta...)
+
+The observation of an environment from the perspective of an agent.
+
+# Keywords & Fields
+
+- `reward`: the reward of an agent
+- `terminal`: indicates whether the environment has terminated or not
+- `state`: the current state of the environment from the perspective of an agent
+- `meta`: other information, like `legal_actions`...
+
+!!! note
+    The `reward` and `terminal` of the first observation, before interacting with an environment, may not be valid.
+"""
 struct Observation{R,T,S,M<:NamedTuple}
     reward::R
     terminal::T
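
As a quick orientation for this new type, here is a minimal sketch of how an `Observation` is produced and consumed. It assumes an already-constructed environment `env` and uses only functions that appear elsewhere in this commit (`observe`, `interact!`, `reset!`, `action_space`):

    obs = observe(env)   # an Observation carrying reward, terminal, state (and meta)
    if obs.terminal      # per the note above, may not be valid before the first interact!
        reset!(env)
    else
        interact!(env, rand(action_space(env)))  # take a random legal action index
    end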

src/environments/atari.jl

Lines changed: 98 additions & 55 deletions
@@ -1,93 +1,136 @@
-using ArcadeLearningEnvironment, GR
+using ArcadeLearningEnvironment, GR, Random
 
 export AtariEnv
 
-mutable struct AtariEnv{To,F} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
     ale::Ptr{Nothing}
-    screen::Array{UInt8,1}
-    getscreen!::F
-    actions::Array{Int64,1}
+    screens::Tuple{Array{UInt8, N}, Array{UInt8, N}}  # a pair of buffers for max-pooling
+    actions::Vector{Int64}
     action_space::DiscreteSpace{Int}
-    observation_space::To
+    observation_space::MultiDiscreteSpace{UInt8, N}
     noopmax::Int
+    frame_skip::Int
     reward::Float32
+    lives::Int
+    seed::S
 end
 
 """
-    AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
-             color_averaging = true, repeat_action_probability = 0.)
-Returns an AtariEnv that can be used in an RLSetup of the
-[ReinforcementLearning](https://github.com/jbrea/ReinforcementLearning.jl)
-package. Check the deps/roms folder of the ArcadeLearningEnvironment package to
-see all available `name`s.
+    AtariEnv(;kwargs...)
+
+This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009).
+
+TODO: support seed! in single/multi thread
+
+# Keywords
+
+- `name::String="pong"`: name of the Atari environment. Use `getROMList` to show all supported environments.
+- `grayscale_obs::Bool=true`: if `true`, a grayscale observation is returned; otherwise, an RGB observation is returned.
+- `noop_max::Int=30`: maximum number of no-op actions performed on reset.
+- `frame_skip::Int=4`: the frequency at which the agent experiences the game.
+- `terminal_on_life_loss::Bool=false`: if `true`, the game is over whenever a life is lost.
+- `repeat_action_probability::Float64=0.`
+- `color_averaging::Bool=false`: whether to perform phosphor averaging or not.
+- `max_num_frames_per_episode::Int=0`
+- `full_action_space::Bool=false`: by default, only the minimal action set is used. If `true`, one needs to call `legal_actions` to get the valid action set. TODO
+
+See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
 """
 function AtariEnv(
-    name;
-    colorspace = "Grayscale",
+    ;name = "pong",
+    grayscale_obs = true,
+    noop_max = 30,
     frame_skip = 4,
-    noopmax = 20,
-    color_averaging = true,
-    actionset = :minimal,
-    repeat_action_probability = 0.,
+    terminal_on_life_loss = false,
+    repeat_action_probability = 0.,
+    color_averaging = false,
+    max_num_frames_per_episode = 0,
+    full_action_space = false,
+    seed = nothing
 )
+    frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
+    name in getROMList() || throw(ArgumentError("unknown ROM name! Run `getROMList()` to see all the game names."))
+
+    if isnothing(seed)
+        seed = (MersenneTwister(), 0)
+    elseif seed isa Tuple{Int, Int}
+        seed = (MersenneTwister(seed[1]), seed[2])
+    else
+        @error "You must specify two seeds: one for the Julia wrapper, one for the internal C implementation"  # ??? maybe auto-generate two seeds from one
+    end
+
     ale = ALE_new()
-    setBool(ale, "color_averaging", color_averaging)
-    setInt(ale, "frame_skip", Int32(frame_skip))
+    setInt(ale, "random_seed", seed[2])
+    setInt(ale, "frame_skip", Int32(1))  # !!! do not use the internal frame_skip here; we max-pool the latest two frames, so the mechanism is implemented manually below
+    setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
     setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
+    setBool(ale, "color_averaging", color_averaging)
     loadROM(ale, name)
-    observation_length = getScreenWidth(ale) * getScreenHeight(ale)
-    if colorspace == "Grayscale"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenGrayscale!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    elseif colorspace == "RGB"
-        screen = Array{Cuchar}(undef, 3 * observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenRGB!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), 3 * observation_length),
-            fill(typemin(Cuchar), 3 * observation_length),
-        )
-    elseif colorspace == "Raw"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreen!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    end
-    actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
+
+    observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale))  # !!! note the order
+    observation_space = MultiDiscreteSpace(
+        fill(typemax(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
     action_space = DiscreteSpace(length(actions))
-    AtariEnv(
+    screens = (
+        fill(typemin(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
         ale,
-        screen,
-        getscreen!,
+        screens,
         actions,
         action_space,
         observation_space,
-        noopmax,
+        noop_max,
+        frame_skip,
         0.0f0,
+        lives(ale),
+        seed[1]
     )
 end
 
-function interact!(env::AtariEnv, a)
-    env.reward = act(env.ale, env.actions[a])
-    env.getscreen!(env.ale, env.screen)
+update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
+update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
+
+function interact!(env::AtariEnv{is_gray_scale, is_terminal_on_life_loss}, a) where {is_gray_scale, is_terminal_on_life_loss}
+    r = 0.0f0
+
+    for i in 1:env.frame_skip
+        r += act(env.ale, env.actions[a])
+        if i == env.frame_skip
+            update_screen!(env, env.screens[1])
+        elseif i == env.frame_skip - 1
+            update_screen!(env, env.screens[2])
+        end
+    end
+
+    # max-pooling over the last two frames
+    if env.frame_skip > 1
+        env.screens[1] .= max.(env.screens[1], env.screens[2])
+    end
+
+    env.reward = r
     nothing
 end
 
-observe(env::AtariEnv) =
-    Observation(reward = env.reward, terminal = game_over(env.ale), state = env.screen)
+is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
+is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)
+
+observe(env::AtariEnv) = Observation(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
 
 function reset!(env::AtariEnv)
     reset_game(env.ale)
-    for _ = 1:rand(0:env.noopmax)
+    for _ = 1:rand(env.seed, 0:env.noopmax)
         act(env.ale, Int32(0))
     end
-    env.getscreen!(env.ale, env.screen)
+    update_screen!(env, env.screens[1])  # no need to update env.screens[2]
     env.reward = 0.0f0  # dummy
+    env.lives = lives(env.ale)
     nothing
 end
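
Putting the new constructor and interaction functions together, a hedged usage sketch (the keyword values shown are the documented defaults except `seed`; the two-element `seed` tuple is split into a `MersenneTwister` for the Julia wrapper and a `random_seed` for the internal C implementation, as the constructor above shows):

    env = AtariEnv(;name = "pong", grayscale_obs = true, frame_skip = 4, seed = (123, 456))
    reset!(env)
    for _ in 1:100
        interact!(env, rand(action_space(env)))  # one agent step = frame_skip emulator frames
        obs = observe(env)
        obs.terminal && reset!(env)
    end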

test/atari.jl

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+@testset "atari" begin
+    @testset "seed" begin
+        env = AtariEnv(;name="pong", seed=(123,456))
+        old_states = []
+        actions = [rand(action_space(env)) for i in 1:10, j in 1:100]
+
+        for i in 1:10
+            for j in 1:100
+                interact!(env, actions[i, j])
+                push!(old_states, copy(observe(env).state))
+            end
+            reset!(env)
+        end
+
+        env = AtariEnv(;name="pong", seed=(123,456))
+        new_states = []
+        for i in 1:10
+            for j in 1:100
+                interact!(env, actions[i, j])
+                push!(new_states, copy(observe(env).state))
+            end
+            reset!(env)
+        end
+
+        @test old_states == new_states
+    end
+
+    @testset "frame_skip" begin
+        env = AtariEnv(;name="pong", frame_skip=4, seed=(123,456))
+        states = []
+        actions = [rand(action_space(env)) for _ in 1:100]
+
+        for i in 1:100
+            interact!(env, actions[i])
+            push!(states, copy(observe(env).state))
+        end
+
+        env = AtariEnv(;name="pong", frame_skip=1, seed=(123,456))
+        for i in 1:100
+            interact!(env, actions[i])
+            interact!(env, actions[i])
+            interact!(env, actions[i])
+            s1 = copy(observe(env).state)
+            interact!(env, actions[i])
+            s2 = copy(observe(env).state)
+            @test states[i] == max.(s1, s2)
+        end
+    end
+
+    @testset "repeat_action_probability" begin
+        env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
+        states = []
+        actions = [rand(action_space(env)) for _ in 1:100]
+        for i in 1:100
+            interact!(env, actions[i])
+            push!(states, copy(observe(env).state))
+        end
+
+        env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
+        for i in 1:100
+            interact!(env, actions[1])
+            @test states[i] == observe(env).state
+        end
+    end
+
+    @testset "max_num_frames_per_episode" begin
+        for i in 1:10
+            env = AtariEnv(;name="pong", max_num_frames_per_episode=i, seed=(123,456))
+            for _ in 1:i
+                interact!(env, 1)
+            end
+            @test true == observe(env).terminal
+        end
+    end
+end
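
The `frame_skip` testset above checks exactly the pooling identity from `interact!`: one `frame_skip=4` step must equal the element-wise maximum of the last two raw frames. At the pixel level the broadcast is simply (a self-contained illustration with made-up values):

    a = UInt8[0 255; 10 0]   # second-to-last raw frame
    b = UInt8[0 0; 20 30]    # last raw frame
    max.(a, b)               # UInt8[0 255; 20 30]: each pixel keeps the brighter value

This matters because some Atari games draw certain sprites only on alternating frames, so a single raw frame can miss objects entirely.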

test/environments.jl

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@
     :(deterministic_tree_MDP_with_rand_reward()),
     :(deterministic_tree_MDP()),
     :(deterministic_MDP()),
-    (:(AtariEnv($x)) for x in atari_env_names)...,
+    (:(AtariEnv(;name=$x)) for x in atari_env_names)...,
     (:(GymEnv($x)) for x in gym_env_names)...,
 ]
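
This one-line change tracks the constructor's switch from a positional `name` argument to keyword-only arguments; illustratively:

    AtariEnv("pong")          # old positional API, removed in this commit
    AtariEnv(;name = "pong")  # new keyword API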

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -10,5 +10,5 @@ using Hanabi
 
 include("spaces.jl")
 include("environments.jl")
-
+include("atari.jl")
 end
