This repository was archived by the owner on May 6, 2021. It is now read-only.

update atari env to use max-pooling #22

Merged
merged 3 commits on Nov 2, 2019
15 changes: 15 additions & 0 deletions src/abstractenv.jl
@@ -20,6 +20,21 @@ function action_space end
function observation_space end
function render end

"""
Observation(;reward, terminal, state, meta...)

The observation of an environment from the perspective of an agent.

# Keywords & Fields

- `reward`: the reward received by the agent
- `terminal`: whether the environment has terminated
- `state`: the current state of the environment from the perspective of the agent
- `meta`: extra information, such as `legal_actions`

!!! note
    The `reward` and `terminal` fields of the first observation, obtained before interacting with the environment, may not be valid.
"""
struct Observation{R,T,S,M<:NamedTuple}
reward::R
terminal::T
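For reference, a minimal usage sketch of the keyword constructor documented above (the values and the `legal_actions` keyword are purely illustrative, assuming extra keywords are collected into the `meta` field as the docstring suggests):

```julia
# Hypothetical example, not part of this PR.
obs = Observation(
    reward = 0.0f0,
    terminal = false,
    state = zeros(UInt8, 210 * 160),
    legal_actions = 1:6,          # illustrative extra keyword, stored in `meta`
)

obs.terminal || @info "episode still running" obs.reward
```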
153 changes: 98 additions & 55 deletions src/environments/atari.jl
@@ -1,93 +1,136 @@
using ArcadeLearningEnvironment, GR
using ArcadeLearningEnvironment, GR, Random

export AtariEnv

mutable struct AtariEnv{To,F} <: AbstractEnv
mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
ale::Ptr{Nothing}
screen::Array{UInt8,1}
getscreen!::F
actions::Array{Int64,1}
screens::Tuple{Array{UInt8, N}, Array{UInt8, N}} # for max-pooling
actions::Vector{Int64}
action_space::DiscreteSpace{Int}
observation_space::To
observation_space::MultiDiscreteSpace{UInt8, N}
noopmax::Int
frame_skip::Int
reward::Float32
lives::Int
seed::S
end

"""
AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
color_averaging = true, repeat_action_probability = 0.)
Returns an AtariEnv that can be used in an RLSetup of the
[ReinforcementLearning](https://github.com/jbrea/ReinforcementLearning.jl)
package. Check the deps/roms folder of the ArcadeLearningEnvironment package to
see all available `name`s.
AtariEnv(;kwargs...)

This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009)

TODO: support seed! in single/multi thread

# Keywords

- `name::String="pong"`: name of the Atari environment. Use `getROMList` to see all supported environments.
- `grayscale_obs::Bool=true`: if `true`, a grayscale observation is returned; otherwise an RGB observation is returned.
- `noop_max::Int=30`: maximum number of random no-op actions performed at the start of each episode.
- `frame_skip::Int=4`: the frequency at which the agent experiences the game.
- `terminal_on_life_loss::Bool=false`: if `true`, the episode terminates whenever a life is lost.
- `repeat_action_probability::Float64=0.`: the probability that the previous action is repeated instead of the chosen one (sticky actions).
- `color_averaging::Bool=false`: whether to perform phosphor (color) averaging.
- `max_num_frames_per_episode::Int=0`: the maximum number of frames per episode.
- `full_action_space::Bool=false`: by default, only the minimal action set is used. If `true`, one needs to call `legal_actions` to get the valid action set. TODO

See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
"""
function AtariEnv(
name;
colorspace = "Grayscale",
;name = "pong",
grayscale_obs=true,
noop_max = 30,
frame_skip = 4,
noopmax = 20,
color_averaging = true,
actionset = :minimal,
repeat_action_probability = 0.,
terminal_on_life_loss=false,
repeat_action_probability=0.,
color_averaging=false,
max_num_frames_per_episode=0,
full_action_space=false,
seed=nothing
)
frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
name in getROMList() || throw(ArgumentError("unknown ROM name! run `getROMList()` to see all the game names."))

if isnothing(seed)
seed = (MersenneTwister(), 0)
elseif seed isa Tuple{Int, Int}
seed = (MersenneTwister(seed[1]), seed[2])
else
@error "You must specify two seeds: one for the Julia wrapper and one for the internal C implementation" # ??? maybe auto-generate two seeds from one
end

ale = ALE_new()
setBool(ale, "color_averaging", color_averaging)
setInt(ale, "frame_skip", Int32(frame_skip))
setInt(ale, "random_seed", seed[2])
setInt(ale, "frame_skip", Int32(1)) # !!! do not use the internal frame_skip here; we need to max-pool the latest two frames, so the frame-skipping mechanism is implemented manually in `interact!`.
setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
setBool(ale, "color_averaging", color_averaging)
loadROM(ale, name)
observation_length = getScreenWidth(ale) * getScreenHeight(ale)
if colorspace == "Grayscale"
screen = Array{Cuchar}(undef, observation_length)
getscreen! = ArcadeLearningEnvironment.getScreenGrayscale!
observation_space = MultiDiscreteSpace(
fill(typemax(Cuchar), observation_length),
fill(typemin(Cuchar), observation_length),
)
elseif colorspace == "RGB"
screen = Array{Cuchar}(undef, 3 * observation_length)
getscreen! = ArcadeLearningEnvironment.getScreenRGB!
observation_space = MultiDiscreteSpace(
fill(typemax(Cuchar), 3 * observation_length),
fill(typemin(Cuchar), 3 * observation_length),
)
elseif colorspace == "Raw"
screen = Array{Cuchar}(undef, observation_length)
getscreen! = ArcadeLearningEnvironment.getScreen!
observation_space = MultiDiscreteSpace(
fill(typemax(Cuchar), observation_length),
fill(typemin(Cuchar), observation_length),
)
end
actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)

observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
observation_space = MultiDiscreteSpace(
fill(typemax(Cuchar), observation_size),
fill(typemin(Cuchar), observation_size),
)

actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
action_space = DiscreteSpace(length(actions))
AtariEnv(
screens = (
fill(typemin(Cuchar), observation_size),
fill(typemin(Cuchar), observation_size),
)

AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
ale,
screen,
getscreen!,
screens,
actions,
action_space,
observation_space,
noopmax,
noop_max,
frame_skip,
0.0f0,
lives(ale),
seed[1]
)
end

function interact!(env::AtariEnv, a)
env.reward = act(env.ale, env.actions[a])
env.getscreen!(env.ale, env.screen)
update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))

function interact!(env::AtariEnv{is_gray_scale, is_terminal_on_life_loss}, a) where {is_gray_scale, is_terminal_on_life_loss}
r = 0.0f0

for i in 1:env.frame_skip
r += act(env.ale, env.actions[a])
if i == env.frame_skip
update_screen!(env, env.screens[1])
elseif i == env.frame_skip - 1
update_screen!(env, env.screens[2])
end
end

# max-pooling
if env.frame_skip > 1
env.screens[1] .= max.(env.screens[1], env.screens[2])
end

env.reward = r
nothing
end

observe(env::AtariEnv) =
Observation(reward = env.reward, terminal = game_over(env.ale), state = env.screen)
is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)

observe(env::AtariEnv) = Observation(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])

function reset!(env::AtariEnv)
reset_game(env.ale)
for _ = 1:rand(0:env.noopmax)
for _ = 1:rand(env.seed, 0:env.noopmax)
act(env.ale, Int32(0))
end
env.getscreen!(env.ale, env.screen)
update_screen!(env, env.screens[1]) # no need to update env.screens[2]
env.reward = 0.0f0 # dummy
env.lives = lives(env.ale)
nothing
end

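For reference, a minimal sketch of how the updated environment might be driven with the new keyword interface (mirroring the calls used in `test/atari.jl` below; the loop itself is illustrative and not part of this PR):

```julia
# Assumes AtariEnv, action_space, interact!, observe and reset! are in scope,
# exactly as they are used in test/atari.jl.
env = AtariEnv(; name = "pong", grayscale_obs = true, frame_skip = 4,
               terminal_on_life_loss = false, seed = (123, 456))

reset!(env)
for _ in 1:100
    a = rand(action_space(env))   # sample an index into the minimal action set
    interact!(env, a)             # acts `frame_skip` times and max-pools the last two frames
    obs = observe(env)
    obs.terminal && break
end
```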
75 changes: 75 additions & 0 deletions test/atari.jl
@@ -0,0 +1,75 @@
@testset "atari" begin
@testset "seed" begin
env = AtariEnv(;name="pong", seed=(123,456))
old_states = []
actions = [rand(action_space(env)) for i in 1:10, j in 1:100]

for i in 1:10
for j in 1:100
interact!(env, actions[i, j])
push!(old_states, copy(observe(env).state))
end
reset!(env)
end

env = AtariEnv(;name="pong", seed=(123,456))
new_states = []
for i in 1:10
for j in 1:100
interact!(env, actions[i, j])
push!(new_states, copy(observe(env).state))
end
reset!(env)
end

@test old_states == new_states
end

@testset "frame_skip" begin
env = AtariEnv(;name="pong", frame_skip=4, seed=(123,456))
states = []
actions = [rand(action_space(env)) for _ in 1:100]

for i in 1:100
interact!(env, actions[i])
push!(states, copy(observe(env).state))
end

env = AtariEnv(;name="pong", frame_skip=1, seed=(123,456))
for i in 1:100
interact!(env, actions[i])
interact!(env, actions[i])
interact!(env, actions[i])
s1 = copy(observe(env).state)
interact!(env, actions[i])
s2 = copy(observe(env).state)
@test states[i] == max.(s1, s2)
end
end

@testset "repeat_action_probability" begin
env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
states = []
actions = [rand(action_space(env)) for _ in 1:100]
for i in 1:100
interact!(env, actions[i])
push!(states, copy(observe(env).state))
end

env = AtariEnv(;name="pong", repeat_action_probability=1.0, seed=(123,456))
for i in 1:100
interact!(env, actions[1])
@test states[i] == observe(env).state
end
end

@testset "max_num_frames_per_episode" begin
for i in 1:10
env = AtariEnv(;name="pong", max_num_frames_per_episode=i, seed=(123,456))
for _ in 1:i
interact!(env, 1)
end
@test true == observe(env).terminal
end
end
end
2 changes: 1 addition & 1 deletion test/environments.jl
@@ -62,7 +62,7 @@
:(deterministic_tree_MDP_with_rand_reward()),
:(deterministic_tree_MDP()),
:(deterministic_MDP()),
(:(AtariEnv($x)) for x in atari_env_names)...,
(:(AtariEnv(;name=$x)) for x in atari_env_names)...,
(:(GymEnv($x)) for x in gym_env_names)...,
]

2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -10,5 +10,5 @@ using Hanabi

include("spaces.jl")
include("environments.jl")

include("atari.jl")
end