This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit c325937

Update RLBase to v0.8 (#76)
* sync
* support next version of RLBase
* fix traits in OpenSpiel
* change seed to rng
* use one seed only to simplify api
* set default log level in Atari env to :error
* upgrade POMDP to v0.9
* rename TupleSpace to VectSpace
* update README
* rename render to display
* fix test cases
1 parent afab24c commit c325937
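
The renames listed above (`observe` → `get_*` accessors, `render` → `display`, `seed` → `rng`) amount to roughly the following migration for downstream code. This is a minimal sketch; `CartPoleEnv` and the seed value are only placeholders.

```julia
# Rough before/after sketch of the v0.7 -> v0.8 API changes touched by this commit.
using ReinforcementLearningEnvironments
using ReinforcementLearningBase
using Random

env = CartPoleEnv(rng = MersenneTwister(123))  # was: CartPoleEnv(seed = 123)

# was: obs = observe(env); obs.reward, obs.terminal, obs.state
get_reward(env)
get_terminal(env)
get_state(env)

actions = get_actions(env)   # was: get_action_space(env)
env(rand(actions))           # stepping the environment is unchanged

display(env)                 # was: RLBase.render(env)
```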

File tree

17 files changed (+295 −293 lines)


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 [compat]
 GR = "0.46, 0.47, 0.48, 0.49, 0.50, 0.51"
 OrdinaryDiffEq = "5"
-ReinforcementLearningBase = "0.7"
+ReinforcementLearningBase = "0.8"
 Requires = "1.0"
 StatsBase = "0.32, 0.33"
 julia = "1.3"

README.md

Lines changed: 48 additions & 10 deletions
@@ -25,15 +25,16 @@ By default, only some basic environments are installed. If you want to use some
 - MountainCarEnv
 - ContinuousMountainCarEnv
 - PendulumEnv
+- PendulumNonInteractiveEnv
 
 ### 3-rd Party Environments
 
 | Environment Name | Dependent Package Name | Description |
 | :--- | :--- | :--- |
-| `AtariEnv` | [ArcadeLearningEnvironment.jl](https://github.com/JuliaReinforcementLearning/ArcadeLearningEnvironment.jl) | Tested only on Linux|
-| `ViZDoomEnv` | [ViZDoom.jl](https://github.com/JuliaReinforcementLearning/ViZDoom.jl) | Currently only a basic environment is supported. (By calling `basic_ViZDoom_env()`)|
-| `GymEnv` | [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) | Tested only on Linux |
-| `MDPEnv`,`POMDPEnv`| [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl)| The `get_observation_space` method is undefined|
+| `AtariEnv` | [ArcadeLearningEnvironment.jl](https://github.com/JuliaReinforcementLearning/ArcadeLearningEnvironment.jl) | |
+| `ViZDoomEnv` | [ViZDoom.jl](https://github.com/JuliaReinforcementLearning/ViZDoom.jl) | Broken [help wanted](https://github.com/JuliaReinforcementLearning/ViZDoom.jl/issues/7) |
+| `GymEnv` | [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) | |
+| `MDPEnv`,`POMDPEnv`| [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl)| Tested with `POMDPs@v0.9`|
 | `OpenSpielEnv` | [OpenSpiel.jl](https://github.com/JuliaReinforcementLearning/OpenSpiel.jl) | |
 
 ## Usage
@@ -44,15 +45,52 @@ julia> using ReinforcementLearningEnvironments
 julia> using ReinforcementLearningBase
 
 julia> env = CartPoleEnv()
-CartPoleEnv{Float64}(gravity=9.8,masscart=1.0,masspole=0.1,totalmass=1.1,halflength=0.5,polemasslength=0.05,forcemag=10.0,tau=0.02,thetathreshold=0.20943951023931953,xthreshold=2.4,max_steps=200)
+# CartPoleEnv
 
-julia> action_space = get_action_space(env)
+## Traits
+
+| Trait Type       |                Value |
+|:---------------- | --------------------:|
+| NumAgentStyle    |        SingleAgent() |
+| DynamicStyle     |         Sequential() |
+| InformationStyle | PerfectInformation() |
+| ChanceStyle      |      Deterministic() |
+| RewardStyle      |         StepReward() |
+| UtilityStyle     |         GeneralSum() |
+| ActionStyle      |   MinimalActionSet() |
+
+## Actions
+
+DiscreteSpace{UnitRange{Int64}}(1:2)
+
+## Players
+
+* `DEFAULT_PLAYER`
+
+## Current Player
+
+`DEFAULT_PLAYER`
+
+## Is Environment Terminated?
+
+No
+
+julia> get_state(env)
+4-element Array{Float64,1}:
+  0.02688439956517477
+ -0.0003235577964125977
+  0.019563124862911535
+ -0.01897808522860225
+
+julia> actions = get_actions(env)
 DiscreteSpace{UnitRange{Int64}}(1:2)
 
 julia> while true
-           action = rand(action_space)
-           env(action)
-           obs = observe(env)
-           get_terminal(obs) && break
+           env(rand(actions))
+           get_terminal(env) && break
        end
 ```
+
+## Application
+
+Checkout [atari.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningZoo.jl/blob/master/src/experiments/atari.jl) for some more complicated cases on how to use these environments and the [wrappers](https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/blob/master/src/implementations/environments.jl) provided in [ReinforcementLearningBase.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl).
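
The updated usage section stops at a random rollout. A slightly fuller loop, built only on the accessors it introduces, might look like the sketch below; `run_episode!` is an illustrative helper, not part of the package.

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

# Illustrative helper: run one episode with a random policy and sum the rewards.
function run_episode!(env)
    reset!(env)
    total_reward = 0.0
    while !get_terminal(env)
        env(rand(get_actions(env)))      # take a random action
        total_reward += get_reward(env)  # reward of the step just taken
    end
    return total_reward
end

run_episode!(CartPoleEnv())
```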

src/environments/atari.jl

Lines changed: 18 additions & 15 deletions
@@ -6,8 +6,6 @@ using .ArcadeLearningEnvironment
 
 This implementation follows the guidelines in [Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents](https://arxiv.org/abs/1709.06009)
 
-TODO: support seed! in single/multi thread
-
 # Keywords
 
 - `name::String="pong"`: name of the Atari environments. Use `ReinforcementLearningEnvironments.list_atari_rom_names()` to show all supported environments.
@@ -19,6 +17,8 @@ TODO: support seed! in single/multi thread
 - `color_averaging::Bool=false`: whether to perform phosphor averaging or not.
 - `max_num_frames_per_episode::Int=0`
 - `full_action_space::Bool=false`: by default, only use minimal action set. If `true`, one need to call `legal_actions` to get the valid action set. TODO
+- `seed::Int` is used to set the initial seed of the underlying C environment and the rng used by the this wrapper environment to initialize the number of no-op steps at the beginning of each episode.
+- `log_level::Symbol`, `:info`, `:warning` or `:error`. Default value is `:error`.
 
 See also the [python implementation](https://github.com/openai/gym/blob/c072172d64bdcd74313d97395436c592dc836d5c/gym/wrappers/atari_preprocessing.py#L8-L36)
 """
@@ -34,25 +34,24 @@ function AtariEnv(;
     max_num_frames_per_episode = 0,
     full_action_space = false,
     seed = nothing,
+    log_level = :error
 )
     frame_skip > 0 || throw(ArgumentError("frame_skip must be greater than 0!"))
     name in getROMList() ||
         throw(ArgumentError("unknown ROM name.\n\nRun `ReinforcementLearningEnvironments.list_atari_rom_names()` to see all the game names."))
 
+    ale = ALE_new()
     if isnothing(seed)
-        seed = (MersenneTwister(), 0)
-    elseif seed isa Tuple{Int,Int}
-        seed = (MersenneTwister(seed[1]), seed[2])
+        rng = Random.GLOBAL_RNG
     else
-        @error "You must specify two seeds, one for Julia wrapper, one for internal C implementation" # ??? maybe auto generate two seed from one
+        setInt(ale, "random_seed", Int32(seed % typemax(Int32)))
+        rng = MersenneTwister(hash(seed+1))
     end
-
-    ale = ALE_new()
-    setInt(ale, "random_seed", seed[2])
     setInt(ale, "frame_skip", Int32(1)) # !!! do not use internal frame_skip here, we need to apply max-pooling for the latest two frames, so we need to manually implement the mechanism.
     setInt(ale, "max_num_frames_per_episode", max_num_frames_per_episode)
     setFloat(ale, "repeat_action_probability", Float32(repeat_action_probability))
     setBool(ale, "color_averaging", color_averaging)
+    setLoggerMode!(log_level)
     loadROM(ale, name)
 
     observation_size = grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) :
@@ -68,8 +67,9 @@ function AtariEnv(;
         (fill(typemin(Cuchar), observation_size), fill(typemin(Cuchar), observation_size))
 
     env =
-        AtariEnv{grayscale_obs,terminal_on_life_loss,grayscale_obs ? 2 : 3,typeof(seed[1])}(
+        AtariEnv{grayscale_obs,terminal_on_life_loss,grayscale_obs ? 2 : 3,typeof(rng)}(
            ale,
+           name,
            screens,
            actions,
           action_space,
@@ -78,7 +78,7 @@ function AtariEnv(;
           frame_skip,
           0.0f0,
           lives(ale),
-          seed[1],
+          rng,
       )
    finalizer(env) do x
        ALE_del(x.ale)
@@ -117,12 +117,15 @@ end
 is_terminal(env::AtariEnv{<:Any,true}) = game_over(env.ale) || (lives(env.ale) < env.lives)
 is_terminal(env::AtariEnv{<:Any,false}) = game_over(env.ale)
 
-RLBase.observe(env::AtariEnv) =
-    (reward = env.reward, terminal = is_terminal(env), state = env.screens[1])
+RLBase.get_name(env::AtariEnv) = "AtariEnv($(env.name))"
+RLBase.get_actions(env::AtariEnv) = env.action_space
+RLBase.get_reward(env::AtariEnv) = env.reward
+RLBase.get_terminal(env::AtariEnv) = is_terminal(env)
+RLBase.get_state(env::AtariEnv) = env.screens[1]
 
 function RLBase.reset!(env::AtariEnv)
     reset_game(env.ale)
-    for _ in 1:rand(env.seed, 0:env.noopmax)
+    for _ in 1:rand(env.rng, 0:env.noopmax)
        act(env.ale, Int32(0))
    end
    update_screen!(env, env.screens[1]) # no need to update env.screens[2]
@@ -148,7 +151,7 @@ function imshowcolor(x::AbstractArray{UInt8,1}, dims)
    updatews()
 end
 
-function RLBase.render(env::AtariEnv)
+function Base.display(env::AtariEnv)
    x = getScreenRGB(env.ale)
    imshowcolor(x, (Int(getScreenWidth(env.ale)), Int(getScreenHeight(env.ale))))
 end
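
Putting the new `seed` and `log_level` keywords together with the accessors, constructing an Atari environment could look like the sketch below. It assumes ArcadeLearningEnvironment.jl is installed so that `AtariEnv` is available; the game name and seed value are arbitrary.

```julia
# Assumes ArcadeLearningEnvironment.jl is installed alongside this package,
# since AtariEnv is only available when that dependency is loaded.
using ArcadeLearningEnvironment
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

# `seed` seeds both the ALE C environment and the Julia RNG that draws the
# random number of no-op steps at reset; `log_level` defaults to :error.
env = AtariEnv(; name = "pong", seed = 123, log_level = :error)

reset!(env)
get_state(env)               # raw screen buffer (grayscale by default)
env(rand(get_actions(env)))
get_reward(env), get_terminal(env)
```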

src/environments/classic_control/acrobot.jl

Lines changed: 6 additions & 4 deletions
@@ -78,7 +78,7 @@ function AcrobotEnv(;
     g = T(9.8),
     dt = T(0.2),
     max_steps = 200,
-    seed = nothing,
+    rng = Random.GLOBAL_RNG,
     book_or_nips = "book",
     avail_torque = [T(-1.0), T(0.0), T(1.0)],
 )
@@ -108,7 +108,7 @@ function AcrobotEnv(;
         0,
         false,
         0,
-        MersenneTwister(seed),
+        rng,
         T(0.0),
         book_or_nips,
         [T(-1.0), T(0.0), T(1.0)],
@@ -119,8 +119,10 @@ end
 
 acrobot_observation(s) = [cos(s[1]), sin(s[1]), cos(s[2]), sin(s[2]), s[3], s[4]]
 
-RLBase.observe(env::AcrobotEnv) =
-    (reward = env.reward, state = acrobot_observation(env.state), terminal = env.done)
+RLBase.get_actions(env::AcrobotEnv) = env.action_space
+RLBase.get_terminal(env::AcrobotEnv) = env.done
+RLBase.get_state(env::AcrobotEnv) = acrobot_observation(env.state)
+RLBase.get_reward(env::AcrobotEnv) = env.reward
 
 function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}
     env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05)
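
The new accessors expose the acrobot state only through `acrobot_observation`, i.e. as a 6-element trigonometric encoding of the internal 4-element state. A minimal sketch of querying them follows; reading the four internal entries as two joint angles plus their angular velocities is an assumption based on the classic Acrobot formulation, not something this diff states.

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = AcrobotEnv()
reset!(env)

get_state(env)     # 6-element vector, assumed layout: [cos θ1, sin θ1, cos θ2, sin θ2, θ̇1, θ̇2]
get_actions(env)   # action space covering the three available torques
get_terminal(env)  # false immediately after reset!
```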

src/environments/classic_control/cartpole.jl

Lines changed: 8 additions & 6 deletions
@@ -45,7 +45,7 @@ Base.show(io::IO, env::CartPoleEnv{T}) where {T} =
 - `forcemag = T(10.0)`
 - `max_steps = 200`
 - 'dt = 0.02'
-- `seed = nothing`
+- `rng = Random.GLOBAL_RNG`
 """
 function CartPoleEnv(;
     T = Float64,
@@ -56,7 +56,7 @@ function CartPoleEnv(;
     forcemag = 10.0,
     max_steps = 200,
     dt = 0.02,
-    seed = nothing,
+    rng = Random.GLOBAL_RNG,
 )
     params = CartPoleEnvParams{T}(
         gravity,
@@ -80,7 +80,7 @@ function CartPoleEnv(;
         2,
         false,
         0,
-        MersenneTwister(seed),
+        rng,
     )
     reset!(cp)
     cp
@@ -96,8 +96,10 @@ function RLBase.reset!(env::CartPoleEnv{T}) where {T<:Number}
     nothing
 end
 
-RLBase.observe(env::CartPoleEnv{T}) where {T} =
-    (reward = env.done ? zero(T) : one(T), terminal = env.done, state = env.state)
+RLBase.get_actions(env::CartPoleEnv) = env.action_space
+RLBase.get_reward(env::CartPoleEnv{T}) where {T} = env.done ? zero(T) : one(T)
+RLBase.get_terminal(env::CartPoleEnv) = env.done
+RLBase.get_state(env::CartPoleEnv) = env.state
 
 function (env::CartPoleEnv)(a)
     @assert a in (1, 2)
@@ -136,7 +138,7 @@ function plotendofepisode(x, y, d)
     end
     return nothing
 end
-function RLBase.render(env::CartPoleEnv)
+function Base.display(env::CartPoleEnv)
     s, a, d = env.state, env.action, env.done
     x, xdot, theta, thetadot = s
     l = 2 * env.params.halflength
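
With the `seed` keyword gone, reproducibility for `CartPoleEnv` comes from passing an explicitly seeded RNG. A minimal sketch; the seed value is arbitrary:

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase
using Random

# Two environments built from identically seeded RNGs should reset to the same
# state, which is how reproducibility is obtained now that `seed` is removed.
env1 = CartPoleEnv(rng = MersenneTwister(2020))
env2 = CartPoleEnv(rng = MersenneTwister(2020))

reset!(env1); reset!(env2)
get_state(env1) == get_state(env2)   # expected to be true
```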

src/environments/classic_control/mountain_car.jl

Lines changed: 8 additions & 6 deletions
@@ -52,7 +52,7 @@ end
 
 - `T = Float64`
 - `continuous = false`
-- `seed = nothing`
+- `rng = Random.GLOBAL_RNG`
 - `min_pos = -1.2`
 - `max_pos = 0.6`
 - `max_speed = 0.07`
@@ -62,7 +62,7 @@ end
 - `power = 0.001`
 - `gravity = 0.0025`
 """
-function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwargs...)
+function MountainCarEnv(; T = Float64, continuous = false, rng = Random.GLOBAL_RNG, kwargs...)
     if continuous
         params = MountainCarEnvParams(; goal_pos = 0.45, power = 0.0015, T = T, kwargs...)
     else
@@ -80,7 +80,7 @@ function MountainCarEnv(; T = Float64, continuous = false, seed = nothing, kwarg
         rand(action_space),
         false,
         0,
-        MersenneTwister(seed),
+        rng,
     )
     reset!(env)
     env
@@ -90,8 +90,10 @@ ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwar
 
 Random.seed!(env::MountainCarEnv, seed) = Random.seed!(env.rng, seed)
 
-RLBase.observe(env::MountainCarEnv) =
-    (reward = env.done ? 0.0 : -1.0, terminal = env.done, state = env.state)
+RLBase.get_actions(env::MountainCarEnv) = env.action_space
+RLBase.get_reward(env::MountainCarEnv{A,T}) where {A, T} = env.done ? zero(T) : -one(T)
+RLBase.get_terminal(env::MountainCarEnv) = env.done
+RLBase.get_state(env::MountainCarEnv) = env.state
 
 function RLBase.reset!(env::MountainCarEnv{A,T}) where {A,T}
     env.state[1] = 0.2 * rand(env.rng, T) - 0.6
@@ -135,7 +137,7 @@ end
 height(xs) = sin(3 * xs) * 0.45 + 0.55
 rotate(xs, ys, θ) = xs * cos(θ) - ys * sin(θ), ys * cos(θ) + xs * sin(θ)
 translate(xs, ys, t) = xs .+ t[1], ys .+ t[2]
-function RLBase.render(env::MountainCarEnv)
+function Base.display(env::MountainCarEnv)
     s = env.state
     d = env.done
     clearws()
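
The `Random.seed!(env, seed)` method kept in this diff re-seeds the wrapper's RNG after construction, which is an alternative to passing `rng` up front. A short sketch under that assumption:

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase
using Random

env = MountainCarEnv(rng = MersenneTwister(0))
Random.seed!(env, 42)   # delegates to Random.seed!(env.rng, 42)
reset!(env)
get_state(env)

# Note: with the default `rng = Random.GLOBAL_RNG`, the call above would re-seed
# the global RNG, so passing a dedicated RNG is usually preferable.
```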

src/environments/classic_control/pendulum.jl

Lines changed: 7 additions & 5 deletions
@@ -38,7 +38,7 @@ end
 - `max_steps = 200`
 - `continuous::Bool = true`
 - `n_actions::Int = 3`
-- `seed = nothing`
+- `rng = Random.GLOBAL_RNG`
 """
 function PendulumEnv(;
     T = Float64,
@@ -51,7 +51,7 @@ function PendulumEnv(;
     max_steps = 200,
     continuous::Bool = true,
     n_actions::Int = 3,
-    seed = nothing,
+    rng = Random.GLOBAL_RNG
 )
     high = T.([1, 1, max_speed])
     action_space = continuous ? ContinuousSpace(-2.0, 2.0) : DiscreteSpace(n_actions)
@@ -62,7 +62,7 @@ function PendulumEnv(;
         zeros(T, 2),
         false,
         0,
-        MersenneTwister(seed),
+        rng,
         zero(T),
         n_actions,
         rand(action_space),
@@ -76,8 +76,10 @@ Random.seed!(env::PendulumEnv, seed) = Random.seed!(env.rng, seed)
 pendulum_observation(s) = [cos(s[1]), sin(s[1]), s[2]]
 angle_normalize(x) = Base.mod((x + Base.π), (2 * Base.π)) - Base.π
 
-RLBase.observe(env::PendulumEnv) =
-    (reward = env.reward, state = pendulum_observation(env.state), terminal = env.done)
+RLBase.get_actions(env::PendulumEnv) = env.action_space
+RLBase.get_reward(env::PendulumEnv) = env.reward
+RLBase.get_terminal(env::PendulumEnv) = env.done
+RLBase.get_state(env::PendulumEnv) = pendulum_observation(env.state)
 
 function RLBase.reset!(env::PendulumEnv{A,T}) where {A,T}
     env.state[1] = 2 * π * (rand(env.rng, T) .- 1)
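
The two helper functions kept in the last hunk are small enough to sanity-check directly. A brief sketch; the helpers are reproduced so the snippet is self-contained, and the numbers are only illustrative:

```julia
# Reproduced from the diff above; inside the package they are internal functions.
pendulum_observation(s) = [cos(s[1]), sin(s[1]), s[2]]
angle_normalize(x) = Base.mod((x + Base.π), (2 * Base.π)) - Base.π

angle_normalize(3π / 2)            # ≈ -π/2: angles are wrapped into [-π, π)
pendulum_observation([π / 2, 0.5]) # ≈ [0.0, 1.0, 0.5]: trig-encoded angle plus angular velocity
```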
