This repository was archived by the owner on May 6, 2021. It is now read-only.

use RLBase instead #29

Merged
2 changes: 1 addition & 1 deletion .travis.yml
@@ -3,7 +3,7 @@ language: julia
os:
- linux
julia:
- 1.2
- 1.3
notifications:
email: false

2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
FROM julia:1.2
FROM julia:1.3

# install dependencies
RUN set -eux; \
29 changes: 9 additions & 20 deletions Project.toml
@@ -4,33 +4,22 @@ authors = ["Jun Tian <[email protected]>"]
version = "0.1.3"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = "1.3"
GR = "0.46"
Requires = "1.0"

[extras]
ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977"
Hanabi = "705708ad-e62c-5f47-9095-732127600058"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ViZDoom = "13bb3beb-38fe-5ca7-9a46-050a216300b1"

[targets]
test = ["Test", "ArcadeLearningEnvironment", "ViZDoom", "PyCall", "Hanabi"]

[compat]
julia = "1"
Distributions = "^0"
GR = "^0"
POMDPModels = "^0"
POMDPs = "^0"
Reexport = "^0"
Requires = "^0"
StatsBase = "^0"
test = ["Test", "ArcadeLearningEnvironment", "PyCall", "POMDPModels", "POMDPs"]
8 changes: 0 additions & 8 deletions benchmarks/Project.toml

This file was deleted.

100 changes: 0 additions & 100 deletions benchmarks/speed_of_random_action.jl

This file was deleted.

48 changes: 0 additions & 48 deletions benchmarks/speed_of_random_action.md

This file was deleted.

10 changes: 4 additions & 6 deletions src/ReinforcementLearningEnvironments.jl
@@ -1,22 +1,20 @@
module ReinforcementLearningEnvironments

using ReinforcementLearningBase

export RLEnvs
const RLEnvs = ReinforcementLearningEnvironments

using Reexport, Requires

include("abstractenv.jl")
include("spaces/spaces.jl")
using Requires

# built-in environments
include("environments/classic_control/classic_control.jl")

# dynamic loading environments
function __init__()
@require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("environments/atari.jl")
@require ViZDoom = "13bb3beb-38fe-5ca7-9a46-050a216300b1" include("environments/vizdoom.jl")
@require PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" include("environments/gym.jl")
@require Hanabi = "705708ad-e62c-5f47-9095-732127600058" include("environments/hanabi.jl")
@require POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" include("environments/mdp.jl")
end

end # module
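
The module now pulls AbstractEnv, the space types, and the rest of the interface from ReinforcementLearningBase instead of the removed abstractenv.jl and spaces/spaces.jl, while Reexport is dropped and Requires keeps handling the optional back ends. A usage sketch of that lazy-loading path (the AtariEnv keyword argument is illustrative, not taken from this diff):

```julia
using ReinforcementLearningEnvironments  # only the classic-control environments are defined at this point

# Loading an optional back end triggers the matching @require block in __init__,
# which then includes environments/atari.jl and defines AtariEnv.
using ArcadeLearningEnvironment

env = AtariEnv(; name = "pong")  # illustrative constructor call; see the atari.jl diff below for details
```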
59 changes: 0 additions & 59 deletions src/abstractenv.jl

This file was deleted.

14 changes: 7 additions & 7 deletions src/environments/atari.jl
@@ -6,8 +6,8 @@ mutable struct AtariEnv{IsGrayScale, TerminalOnLifeLoss, N, S<:AbstractRNG} <: A
ale::Ptr{Nothing}
screens::Tuple{Array{UInt8, N}, Array{UInt8, N}} # for max-pooling
actions::Vector{Int64}
action_space::DiscreteSpace{Int}
observation_space::MultiDiscreteSpace{UInt8, N}
action_space::DiscreteSpace{UnitRange{Int}}
observation_space::MultiDiscreteSpace{Array{UInt8, N}}
noopmax::Int
frame_skip::Int
reward::Float32
@@ -69,8 +69,8 @@ function AtariEnv(

observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) : (3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
observation_space = MultiDiscreteSpace(
fill(typemax(Cuchar), observation_size),
fill(typemin(Cuchar), observation_size),
fill(typemax(Cuchar), observation_size),
)

actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
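
The only substantive change in this constructor hunk is the order of the bounds passed to MultiDiscreteSpace: the lower bound (typemin(Cuchar)) now comes first, presumably to match the (low, high) argument order expected by the RLBase constructor. A tiny, self-contained illustration of the same call under that assumption:

```julia
using ReinforcementLearningBase

# Hypothetical 2x3 grayscale "screen": every pixel ranges over 0x00..0xff.
observation_size = (2, 3)
low  = fill(typemin(Cuchar), observation_size)   # all 0x00
high = fill(typemax(Cuchar), observation_size)   # all 0xff
observation_space = MultiDiscreteSpace(low, high) # assumed (low, high) order, as in the hunk above
```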
@@ -97,7 +97,7 @@ end
update_screen!(env::AtariEnv{true}, screen) = ArcadeLearningEnvironment.getScreenGrayscale!(env.ale, vec(screen))
update_screen!(env::AtariEnv{false}, screen) = ArcadeLearningEnvironment.getScreenRGB!(env.ale, vec(screen))

function interact!(env::AtariEnv{is_gray_scale, is_terminal_on_life_loss}, a) where {is_gray_scale, is_terminal_on_life_loss}
function (env::AtariEnv{is_gray_scale, is_terminal_on_life_loss})(a) where {is_gray_scale, is_terminal_on_life_loss}
r = 0.0f0

for i in 1:env.frame_skip
@@ -121,9 +121,9 @@ end
is_terminal(env::AtariEnv{<:Any, true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
is_terminal(env::AtariEnv{<:Any, false}) = game_over(env.ale)

observe(env::AtariEnv) = Observation(reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
RLBase.observe(env::AtariEnv) = (reward = env.reward, terminal = is_terminal(env), state = env.screens[1])

function reset!(env::AtariEnv)
function RLBase.reset!(env::AtariEnv)
reset_game(env.ale)
for _ = 1:rand(env.seed, 0:env.noopmax)
act(env.ale, Int32(0))
@@ -152,7 +152,7 @@ function imshowcolor(x::Array{UInt8,1}, dims)
updatews()
end

function render(env::AtariEnv)
function RLBase.render(env::AtariEnv)
x = getScreenRGB(env.ale)
imshowcolor(x, (Int(getScreenWidth(env.ale)), Int(getScreenHeight(env.ale))))
end
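
Summing up the atari.jl changes: interact!(env, a) becomes a plain call env(a), observe returns a NamedTuple with reward/terminal/state fields instead of an Observation struct, and reset!, observe, and render now extend the corresponding RLBase functions rather than defining local ones. A rough episode loop against the post-PR API (env is assumed to be an already-constructed AtariEnv; sampling the action space with rand is assumed to be supported by RLBase):

```julia
reset!(env)
while true
    obs = observe(env)           # NamedTuple: (reward = ..., terminal = ..., state = ...)
    obs.terminal && break
    env(rand(env.action_space))  # callable-environment style replaces interact!(env, action)
end
```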