-using ArcadeLearningEnvironment, GR
+using ArcadeLearningEnvironment, GR, Random

 export AtariEnv

-mutable struct AtariEnv{To,F} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: AbstractEnv
     ale::Ptr{Nothing}
-    screen::Array{UInt8,1}
-    getscreen!::F
-    actions::Array{Int64,1}
+    screens::Tuple{Array{UInt8, N}, Array{UInt8, N}}  # for max-pooling
+    actions::Vector{Int64}
     action_space::DiscreteSpace{Int}
-    observation_space::To
+    observation_space::MultiDiscreteSpace{UInt8, N}
     noopmax::Int
+    frame_skip::Int
     reward::Float32
+    lives::Int
+    seed::S
 end

 """
-    AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
-             color_averaging = true, repeat_action_probability = 0.)
-Returns an AtariEnv that can be used in an RLSetup of the
-[ReinforcementLearning](https://github.com/jbrea/ReinforcementLearning.jl)
-package. Check the deps/roms folder of the ArcadeLearningEnvironment package to
-see all available `name`s.
+    AtariEnv(; kwargs...)
+
+This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009).
+
+TODO: support `seed!` in single/multi-threaded use.
+
+# Keywords
+
+- `name::String="pong"`: name of the Atari environment. Use `getROMList` to list all supported environments.
+- `grayscale_obs::Bool=true`: if `true`, a grayscale observation is returned; otherwise, an RGB observation is returned.
+- `noop_max::Int=30`: maximum number of no-op actions performed at reset.
+- `frame_skip::Int=4`: the frequency at which the agent experiences the game.
+- `terminal_on_life_loss::Bool=false`: if `true`, the game is over whenever a life is lost.
+- `repeat_action_probability::Float64=0.`: probability of repeating the previous action ("sticky actions", as proposed in the paper above).
+- `color_averaging::Bool=false`: whether to perform phosphor averaging.
+- `max_num_frames_per_episode::Int=0`: maximum number of frames per episode; `0` means no limit.
+- `full_action_space::Bool=false`: by default, only the minimal action set is used. If `true`, one needs to call `legal_actions` to get the valid action set. TODO
+
+See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36).
 """
 function AtariEnv(
-    name;
-    colorspace = "Grayscale",
+    ; name = "pong",
+    grayscale_obs = true,
+    noop_max = 30,
     frame_skip = 4,
-    noopmax = 20,
-    color_averaging = true,
-    actionset = :minimal,
-    repeat_action_probability = 0.,
+    terminal_on_life_loss = false,
+    repeat_action_probability = 0.,
+    color_averaging = false,
+    max_num_frames_per_episode = 0,
+    full_action_space = false,
+    seed = nothing,
 )
+    frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
+    name in getROMList() || throw(ArgumentError("unknown ROM name! Run `getROMList()` to see all the game names."))
+
+    if isnothing(seed)
+        seed = (MersenneTwister(), 0)
+    elseif seed isa Tuple{Int,Int}
+        seed = (MersenneTwister(seed[1]), seed[2])
+    else
+        throw(ArgumentError("you must specify two seeds: one for the Julia wrapper and one for the internal C implementation"))  # ??? maybe auto-generate the two seeds from one
+    end
+
     ale = ALE_new()
-    setBool(ale, "color_averaging", color_averaging)
-    setInt(ale, "frame_skip", Int32(frame_skip))
+    setInt(ale, "random_seed", seed[2])
+    setInt(ale, "frame_skip", Int32(1))  # !!! do not use the internal frame_skip here; the latest two frames must be max-pooled, so the skipping mechanism is implemented manually in `interact!`
+    setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
     setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
+    setBool(ale, "color_averaging", color_averaging)
     loadROM(ale, name)
-    observation_length = getScreenWidth(ale) * getScreenHeight(ale)
-    if colorspace == "Grayscale"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenGrayscale!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    elseif colorspace == "RGB"
-        screen = Array{Cuchar}(undef, 3 * observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreenRGB!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), 3 * observation_length),
-            fill(typemin(Cuchar), 3 * observation_length),
-        )
-    elseif colorspace == "Raw"
-        screen = Array{Cuchar}(undef, observation_length)
-        getscreen! = ArcadeLearningEnvironment.getScreen!
-        observation_space = MultiDiscreteSpace(
-            fill(typemax(Cuchar), observation_length),
-            fill(typemin(Cuchar), observation_length),
-        )
-    end
-    actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
+
+    observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale))  # !!! note the order
+    observation_space = MultiDiscreteSpace(
+        fill(typemax(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
     action_space = DiscreteSpace(length(actions))
-    AtariEnv(
+    screens = (
+        fill(typemin(Cuchar), observation_size),
+        fill(typemin(Cuchar), observation_size),
+    )
+
+    AtariEnv{grayscale_obs, terminal_on_life_loss, grayscale_obs ? 2 : 3, typeof(seed[1])}(
         ale,
-        screen,
-        getscreen!,
+        screens,
         actions,
         action_space,
        observation_space,
-        noopmax,
+        noop_max,
+        frame_skip,
         0.0f0,
+        lives(ale),
+        seed[1],
     )
 end

-function interact!(env::AtariEnv, a)
-    env.reward = act(env.ale, env.actions[a])
-    env.getscreen!(env.ale, env.screen)
+update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
+update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))
+
+function interact!(env::AtariEnv{is_gray_scale, is_terminal_on_life_loss}, a) where {is_gray_scale, is_terminal_on_life_loss}
+    r = 0.0f0
+
+    for i in 1:env.frame_skip
+        r += act(env.ale, env.actions[a])
+        if i == env.frame_skip
+            update_screen!(env, env.screens[1])
+        elseif i == env.frame_skip - 1
+            update_screen!(env, env.screens[2])
+        end
+    end
+
+    # max-pooling over the two most recent frames
+    if env.frame_skip > 1
+        env.screens[1] .= max.(env.screens[1], env.screens[2])
+    end
+
+    env.reward = r
     nothing
 end

-observe(env::AtariEnv) =
-    Observation(reward = env.reward, terminal = game_over(env.ale), state = env.screen)
+is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
+is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)
+
+observe(env::AtariEnv) = Observation(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])

 function reset!(env::AtariEnv)
     reset_game(env.ale)
-    for _ = 1:rand(0:env.noopmax)
+    for _ = 1:rand(env.seed, 0:env.noopmax)
         act(env.ale, Int32(0))
     end
-    env.getscreen!(env.ale, env.screen)
+    update_screen!(env, env.screens[1])  # no need to update env.screens[2]
     env.reward = 0.0f0  # dummy
+    env.lives = lives(env.ale)
     nothing
 end
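A minimal usage sketch of the new keyword API and the `reset!`/`interact!`/`observe` loop introduced above. The keyword values are illustrative, `run_random_episode` is a hypothetical helper, and the field access on the returned `Observation` (`obs.reward`, `obs.terminal`) is assumed here.

using Random

# Random-policy rollout; assumes the definitions above are in scope.
function run_random_episode(env)
    reset!(env)
    total_reward = 0.0f0
    done = false
    while !done
        a = rand(env.seed, 1:length(env.actions))  # uniformly random index into the action set
        interact!(env, a)
        obs = observe(env)
        total_reward += obs.reward
        done = obs.terminal
    end
    total_reward
end

# The seed is a Tuple{Int,Int}: one for the Julia wrapper RNG, one for the internal C implementation.
env = AtariEnv(;
    name = "pong",
    grayscale_obs = true,
    frame_skip = 4,                    # the last two of the skipped frames are max-pooled in `interact!`
    repeat_action_probability = 0.25,  # sticky actions, the value recommended in the cited paper
    seed = (123, 456),
)
run_random_episode(env)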