This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 155050d

jbrea authored and findmyway committed

Add ContinuousMountainCar (#8)

* Int64 -> Int
* add continuous mountain car

1 parent 9c4a511 · commit 155050d

File tree

8 files changed: +63 −42 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ By default, only some basic environments are installed. If you want to use some
 - CartPoleEnv
 - MountainCarEnv
+- ContinuousMountainCarEnv
 - PendulumEnv
 - MDPEnv
 - POMDPEnv

src/environments/atari.jl

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ struct AtariEnv{To,F} <: AbstractEnv
     actions::Array{Int32, 1}
     action_space::DiscreteSpace{Int}
     observation_space::To
-    noopmax::Int64
+    noopmax::Int
 end
 
 action_space(env::AtariEnv) = env.action_space
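
The `Int64 -> Int` rename repeated throughout this commit is a portability fix rather than a behavior change: in Julia, `Int` is an alias for the platform's native integer type, and integer literals default to it. A quick self-contained check (illustrative, not part of the diff):

# `Int` aliases the native machine integer: Int64 on 64-bit builds,
# Int32 on 32-bit builds. Hard-coding Int64 in field types and method
# signatures would mismatch default integer values on 32-bit platforms.
@assert Int === (Sys.WORD_SIZE == 64 ? Int64 : Int32)
@assert typeof(1) === Int   # literals are `Int`, whatever the platform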

src/environments/classic_control/cart_pole.jl

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@ struct CartPoleEnvParams{T}
     tau::T
     thetathreshold::T
     xthreshold::T
-    max_steps::Int64
+    max_steps::Int
 end
 
 mutable struct CartPoleEnv{T, R<:AbstractRNG} <: AbstractEnv
@@ -24,7 +24,7 @@ mutable struct CartPoleEnv{T, R<:AbstractRNG} <: AbstractEnv
     state::Array{T, 1}
     action::Int
     done::Bool
-    t::Int64
+    t::Int
     rng::R
 end

src/environments/classic_control/mdp.jl

Lines changed: 13 additions & 13 deletions

@@ -63,7 +63,7 @@ MDPEnv(model; rng=Random.GLOBAL_RNG) = MDPEnv(
 action_space(env::Union{MDPEnv, POMDPEnv}) = env.action_space
 observation_space(env::Union{MDPEnv, POMDPEnv}) = env.observation_space
 
-observationindex(env, o) = Int64(o) + 1
+observationindex(env, o) = Int(o) + 1
 
 function reset!(env::Union{POMDPEnv, MDPEnv})
     initialstate(env.model, env.rng)
@@ -89,13 +89,13 @@ end
 #####
 """
     mutable struct SimpleMDPEnv
-        ns::Int64
-        na::Int64
-        state::Int64
+        ns::Int
+        na::Int
+        state::Int
         trans_probs::Array{AbstractArray, 2}
         reward::R
-        initialstates::Array{Int64, 1}
-        isterminal::Array{Int64, 1}
+        initialstates::Array{Int, 1}
+        isterminal::Array{Int, 1}
        rng::S
 A Markov Decision Process with `ns` states, `na` actions, current `state`,
 `na`x`ns` - array of transition probabilities `trans_probs` which consists for
@@ -110,11 +110,11 @@ probabilities) `reward` of type `R` (see [`DeterministicStateActionReward`](@ref
 mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
     observation_space::DiscreteSpace
     action_space::DiscreteSpace
-    state::Int64
+    state::Int
     trans_probs::Array{T, 2}
     reward::R
-    initialstates::Array{Int64, 1}
-    isterminal::Array{Int64, 1}
+    initialstates::Array{Int, 1}
+    isterminal::Array{Int, 1}
     rng::S
 end
 
@@ -186,10 +186,10 @@ expected_rewards(r::NormalStateActionReward, ::Any) = r.mean
 
 # run SimpleMDPEnv
 """
-    run!(mdp::SimpleMDPEnv, action::Int64)
+    run!(mdp::SimpleMDPEnv, action::Int)
 Transition to a new state given `action`. Returns the new state.
 """
-function run!(mdp::SimpleMDPEnv, action::Int64)
+function run!(mdp::SimpleMDPEnv, action::Int)
     if mdp.isterminal[mdp.state] == 1
         reset!(mdp)
     else
@@ -199,9 +199,9 @@ function run!(mdp::SimpleMDPEnv, action::Int)
 end
 
 """
-    run!(mdp::SimpleMDPEnv, policy::Array{Int64, 1}) = run!(mdp, policy[mdp.state])
+    run!(mdp::SimpleMDPEnv, policy::Array{Int, 1}) = run!(mdp, policy[mdp.state])
 """
-run!(mdp::SimpleMDPEnv, policy::Array{Int64, 1}) = run!(mdp, policy[mdp.state])
+run!(mdp::SimpleMDPEnv, policy::Array{Int, 1}) = run!(mdp, policy[mdp.state])
 
 
 function interact!(env::SimpleMDPEnv, action)
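
For orientation (not shown in the diff itself): the `Array{Int, 1}` method lets a caller pass a whole policy vector, from which the action for the current state is looked up. A hypothetical call, assuming `mdp` is an already-constructed `SimpleMDPEnv` with three states, would be:

# Hypothetical usage sketch; assumes `mdp` is a SimpleMDPEnv with ns == 3.
policy = [2, 1, 2]   # action to take in states 1, 2, 3 (a Vector{Int} on any platform)
run!(mdp, policy)    # runs policy[mdp.state] and returns the new state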

src/environments/classic_control/mountain_car.jl

Lines changed: 41 additions & 22 deletions

@@ -1,62 +1,81 @@
 using Random
 using GR
 
-export MountainCarEnv
+export MountainCarEnv, ContinuousMountainCarEnv
 
 struct MountainCarEnvParams{T}
     min_pos::T
     max_pos::T
     max_speed::T
     goal_pos::T
-    max_steps::Int64
+    goal_velocity::T
+    power::T
+    gravity::T
+    max_steps::Int
+end
+function MountainCarEnvParams(; T = Float64, min_pos = -1.2, max_pos = .6,
+                                max_speed = .07, goal_pos = .5, max_steps = 200,
+                                goal_velocity = .0, power = .001, gravity = .0025)
+    MountainCarEnvParams{T}(min_pos, max_pos, max_speed, goal_pos,
+                            goal_velocity, power, gravity, max_steps)
 end
 
-mutable struct MountainCarEnv{T, R<:AbstractRNG} <: AbstractEnv
+mutable struct MountainCarEnv{A, T, R<:AbstractRNG} <: AbstractEnv
     params::MountainCarEnvParams{T}
-    action_space::DiscreteSpace
+    action_space::A
     observation_space::MultiContinuousSpace{(2,), 1}
     state::Array{T, 1}
-    action::Int64
+    action::Int
     done::Bool
-    t::Int64
+    t::Int
     rng::R
 end
 
-function MountainCarEnv(; T = Float64, min_pos = T(-1.2), max_pos = T(.6),
-                          max_speed = T(.07), goal_pos = T(.5), max_steps = 200)
-    env = MountainCarEnv(MountainCarEnvParams(min_pos, max_pos, max_speed, goal_pos, max_steps),
-                         DiscreteSpace(3),
-                         MultiContinuousSpace([min_pos, -max_speed], [max_pos, max_speed]),
-                         zeros(T, 2),
-                         1,
-                         false,
-                         0,
-                         Random.GLOBAL_RNG)
+function MountainCarEnv(; T = Float64, continuous = false,
+                          rng = Random.GLOBAL_RNG, kwargs...)
+    if continuous
+        params = MountainCarEnvParams(; goal_pos = .45, power = .0015, T = T, kwargs...)
+    else
+        params = MountainCarEnvParams(; kwargs...)
+    end
+    env = MountainCarEnv(params,
+                         continuous ? ContinuousSpace(-T(1.), T(1.)) : DiscreteSpace(3),
+                         MultiContinuousSpace([params.min_pos, -params.max_speed],
+                                              [params.max_pos, params.max_speed]),
+                         zeros(T, 2),
+                         1,
+                         false,
+                         0,
+                         rng)
     reset!(env)
     env
 end
+ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwargs...)
 
 action_space(env::MountainCarEnv) = env.action_space
 observation_space(env::MountainCarEnv) = env.observation_space
 observe(env::MountainCarEnv) = (observation=env.state, isdone=env.done)
 
-function reset!(env::MountainCarEnv{T}) where T
+function reset!(env::MountainCarEnv{A, T}) where {A, T}
     env.state[1] = .2 * rand(env.rng, T) - .6
     env.state[2] = 0.
     env.done = false
     env.t = 0
     nothing
 end
 
-function interact!(env::MountainCarEnv, a)
+interact!(env::MountainCarEnv{<:ContinuousSpace}, a) = _interact!(env, min(max(a, -1), 1))
+interact!(env::MountainCarEnv{<:DiscreteSpace}, a) = _interact!(env, a - 2)
+function _interact!(env::MountainCarEnv, force)
     env.t += 1
     x, v = env.state
-    v += (a - 2)*0.001 + cos(3*x)*(-0.0025)
+    v += force * env.params.power + cos(3*x)*(-env.params.gravity)
     v = clamp(v, -env.params.max_speed, env.params.max_speed)
     x += v
     x = clamp(x, env.params.min_pos, env.params.max_pos)
     if x == env.params.min_pos && v < 0 v = 0 end
-    env.done = x >= env.params.goal_pos || env.t >= env.params.max_steps
+    env.done = x >= env.params.goal_pos && v >= env.params.goal_velocity ||
+               env.t >= env.params.max_steps
    env.state[1] = x
    env.state[2] = v
    (observation=env.state, reward=-1., isdone=env.done)
@@ -87,6 +106,6 @@ function render(env::MountainCarEnv)
     xs, ys = rotate(xs, ys, θ)
     xs, ys = translate(xs, ys, [x, height(x)])
     fillarea(xs, ys)
-    plotendofepisode(env.params.max_pos + .1, 0, d)
+    plotendofepisode(env.params.max_pos + .1, 0, d)
     updatews()
-end
+end
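
A short usage sketch of the two constructor paths added here (illustrative; it relies only on functions defined in this file). The discrete variant takes actions in 1:3, which `interact!` shifts to forces -1, 0, 1; the continuous variant accepts a real-valued force clamped to [-1, 1]:

env  = MountainCarEnv()               # discrete actions: DiscreteSpace(3)
cenv = ContinuousMountainCarEnv()     # continuous force: ContinuousSpace(-1., 1.)

reset!(cenv)
while true
    step = interact!(cenv, 1.0)       # full push to the right each step
    step.isdone && break              # ends at the goal or after max_steps
end
observe(cenv)                         # (observation = cenv.state, isdone = true)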

src/environments/classic_control/pendulum.jl

Lines changed: 2 additions & 2 deletions

@@ -9,7 +9,7 @@ struct PendulumEnvParams{T}
     m::T
     l::T
     dt::T
-    max_steps::Int64
+    max_steps::Int
 end
 
 mutable struct PendulumEnv{T, R<:AbstractRNG} <: AbstractEnv
@@ -18,7 +18,7 @@ mutable struct PendulumEnv{T, R<:AbstractRNG} <: AbstractEnv
     observation_space::MultiContinuousSpace{(3,), 1}
     state::Array{T, 1}
     done::Bool
-    t::Int64
+    t::Int
     rng::R
 end

src/environments/hanabi.jl

Lines changed: 2 additions & 2 deletions

@@ -91,8 +91,8 @@ mutable struct HanabiEnv <: AbstractEnv
     state::Base.RefValue{Hanabi.LibHanabi.PyHanabiState}
     moves::Vector{Base.RefValue{Hanabi.LibHanabi.PyHanabiMove}}
     observation_encoder::Base.RefValue{Hanabi.LibHanabi.PyHanabiObservationEncoder}
-    observation_space::MultiDiscreteSpace{Int64, 1}
-    action_space::DiscreteSpace{Int64}
+    observation_space::MultiDiscreteSpace{Int, 1}
+    action_space::DiscreteSpace{Int}
     reward::HanabiResult
 
     function HanabiEnv(;kw...)

test/environments.jl

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@
     :(basic_ViZDoom_env()),
     :(CartPoleEnv()),
     :(MountainCarEnv()),
+    :(ContinuousMountainCarEnv()),
    :(PendulumEnv()),
    :(MDPEnv(LegacyGridWorld())),
    :(POMDPEnv(TigerPOMDP())),
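
These entries are quoted expressions, presumably `eval`ed one by one by the surrounding test loop so each environment is constructed and smoke-tested; a minimal sketch of that pattern, assuming `reset!`/`observe` as the shared interface (the repo's actual harness may differ):

for ex in (:(MountainCarEnv()), :(ContinuousMountainCarEnv()))
    env = eval(ex)                        # construct the environment lazily
    reset!(env)
    @assert observe(env).isdone == false  # a fresh episode is not done
end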
