Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 6d23062

Browse files
committed
adapt to new ALE version
1 parent 69316a1 commit 6d23062

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

src/environments/atari.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ using ArcadeLearningEnvironment, GR
22

33
export AtariEnv
44

5-
struct AtariEnv{To} <: AbstractEnv
5+
struct AtariEnv{To,F} <: AbstractEnv
66
ale::Ptr{Nothing}
77
screen::Array{UInt8, 1}
8-
getscreen::Function
8+
getscreen!::F
99
actions::Array{Int32, 1}
1010
action_space::DiscreteSpace{Int}
1111
observation_space::To
@@ -38,44 +38,37 @@ function AtariEnv(name;
3838
observation_length = getScreenWidth(ale) * getScreenHeight(ale)
3939
if colorspace == "Grayscale"
4040
screen = Array{Cuchar}(undef, observation_length)
41-
getscreen = getScreenGrayscale
41+
getscreen! = ArcadeLearningEnvironment.getScreenGrayscale!
4242
observation_space = MultiDiscreteSpace(fill(typemax(Cuchar), observation_length), fill(typemin(Cuchar), observation_length))
4343
elseif colorspace == "RGB"
4444
screen = Array{Cuchar}(undef, 3*observation_length)
45-
getscreen = getScreenRGB
45+
getscreen! = ArcadeLearningEnvironment.getScreenRGB!
4646
observation_space = MultiDiscreteSpace(fill(typemax(Cuchar), 3*observation_length), fill(typemin(Cuchar), 3*observation_length))
4747
elseif colorspace == "Raw"
4848
screen = Array{Cuchar}(undef, observation_length)
49-
getscreen = getScreen
49+
getscreen! = ArcadeLearningEnvironment.getScreen!
5050
observation_space = MultiDiscreteSpace(fill(typemax(Cuchar), observation_length), fill(typemin(Cuchar), observation_length))
5151
end
5252
actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
5353
action_space = DiscreteSpace(length(actions))
54-
AtariEnv(ale, screen, getscreen, actions, action_space, observation_space, noopmax)
55-
end
56-
57-
function getScreen(p::Ptr, s::Array{Cuchar, 1})
58-
sraw = getScreen(p)
59-
for i in 1:length(s)
60-
s[i] = sraw[i] .>> 1
61-
end
54+
AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax)
6255
end
6356

6457
function interact!(env::AtariEnv, a)
6558
r = act(env.ale, env.actions[a])
66-
env.getscreen(env.ale, env.screen)
59+
env.getscreen!(env.ale, env.screen)
6760
(observation=env.screen, reward=r, isdone=game_over(env.ale))
6861
end
6962

7063
function observe(env::AtariEnv)
71-
env.getscreen(env.ale, env.screen)
64+
env.getscreen!(env.ale, env.screen)
7265
(observation=env.screen, isdone=game_over(env.ale))
7366
end
7467

7568
function reset!(env::AtariEnv)
7669
reset_game(env.ale)
7770
for _ in 1:rand(0:env.noopmax) act(env.ale, Int32(0)) end
78-
env.getscreen(env.ale, env.screen)
71+
env.getscreen!(env.ale, env.screen)
7972
nothing
8073
end
8174

@@ -98,11 +91,9 @@ function imshowcolor(x::Array{UInt8, 1}, dims)
9891
end
9992

10093
function render(env::AtariEnv)
101-
x = zeros(UInt8, 3 * 160 * 210)
102-
getScreenRGB(env.ale, x)
103-
imshowcolor(x, (160, 210))
94+
x = getScreenRGB(env.ale)
95+
imshowcolor(x, (Int(getScreenWidth(env.ale)),
96+
Int(getScreenHeight(env.ale))))
10497
end
10598

106-
function list_atari_rom_names()
107-
[splitext(x)[1] for x in readdir(joinpath(dirname(pathof(ArcadeLearningEnvironment)), "..", "deps", "roms"))]
108-
end
99+
list_atari_rom_names() = getROMList()

0 commit comments

Comments
 (0)