This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 1e82202

Make AcrobotEnv optional to reduce first loading time (#139)

1 parent db5a792
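This commit makes AcrobotEnv an optional feature via Requires.jl: instead of importing OrdinaryDiffEq eagerly, the package registers a hook in __init__() that includes the AcrobotEnv code only once the user loads OrdinaryDiffEq in the same session, which is what shaves the first-load time. A minimal sketch of the pattern used in the diff below (MyPkg and heavy_feature.jl are placeholder names, not from this repository):

module MyPkg

using Requires  # provides the @require macro

function __init__()
    # The callback body runs only after the user executes
    # `using OrdinaryDiffEq`, so the heavy ODE solver never
    # delays `using MyPkg` itself.
    @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include("heavy_feature.jl")
end

end # module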

File tree: 7 files changed, +73 -69 lines


Project.toml

Lines changed: 0 additions & 2 deletions
@@ -8,7 +8,6 @@ GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
 IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
-OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -18,7 +17,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 GR = "0.46, 0.47, 0.48, 0.49, 0.50, 0.51, 0.52, 0.53, 0.54, 0.55"
 IntervalSets = "0.5"
 MacroTools = "0.5"
-OrdinaryDiffEq = "5"
 ReinforcementLearningBase = "0.9.2"
 Requires = "1.0"
 StatsBase = "0.32, 0.33"

src/ReinforcementLearningEnvironments.jl

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,7 @@ using Random
 using GR
 using Requires
 using IntervalSets
-using Base.Threads: @spawn
+using Base.Threads:@spawn
 using Markdown

 const RLEnvs = ReinforcementLearningEnvironments
@@ -29,6 +29,11 @@ function __init__()
     @require SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429" include(
         "environments/3rd_party/snake.jl",
     )
+    @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include(
+        "environments/3rd_party/AcrobotEnv.jl",
+    )
+
+
 end

 end # module
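With the hook in place, loading the package stays fast and AcrobotEnv becomes available on demand. A usage sketch, assuming both packages are installed (Requires.jl fires the hook regardless of load order):

using ReinforcementLearningEnvironments  # fast: no ODE solver pulled in
using OrdinaryDiffEq                     # triggers the @require hook above

env = AcrobotEnv()                       # the keyword constructor is now loaded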

src/environments/examples/AcrobotEnv.jl renamed to src/environments/3rd_party/AcrobotEnv.jl

Lines changed: 21 additions & 64 deletions
@@ -1,46 +1,3 @@
-import OrdinaryDiffEq
-
-export AcrobotEnv
-
-struct AcrobotEnvParams{T}
-    link_length_a::T # [m]
-    link_length_b::T # [m]
-    link_mass_a::T #: [kg] mass of link 1
-    link_mass_b::T #: [kg] mass of link 2
-    #: [m] position of the center of mass of link 1
-    link_com_pos_a::T
-    #: [m] position of the center of mass of link 2
-    link_com_pos_b::T
-    #: Rotation related parameters
-    link_moi::T
-    max_torque_noise::T
-    #: [m/s] maximum velocity of link 1
-    max_vel_a::T
-    #: [m/s] maximum velocity of link 2
-    max_vel_b::T
-    #: [m/s2] acceleration due to gravity
-    g::T
-    #: [s] timestep
-    dt::T
-    #: maximum steps in episode
-    max_steps::Int
-end
-
-mutable struct AcrobotEnv{T,R<:AbstractRNG} <: AbstractEnv
-    params::AcrobotEnvParams{T}
-    state::Vector{T}
-    action::Int
-    done::Bool
-    t::Int
-    rng::R
-    reward::T
-    # difference in second link angular acceleration equation
-    # as per python gym
-    book_or_nips::String
-    # array of available torques based on actions
-    avail_torque::Vector{T}
-end
-
 """
     AcrobotEnv(;kwargs...)
 # Keyword arguments
@@ -61,23 +18,23 @@ AcrobotEnv(;kwargs...)
 - `avail_torque = [T(-1.), T(0.), T(1.)]`
 """
 function AcrobotEnv(;
-    T = Float64,
-    link_length_a = T(1.0),
-    link_length_b = T(1.0),
-    link_mass_a = T(1.0),
-    link_mass_b = T(1.0),
-    link_com_pos_a = T(0.5),
-    link_com_pos_b = T(0.5),
-    link_moi = T(1.0),
-    max_torque_noise = T(0.0),
-    max_vel_a = T(4 * π),
-    max_vel_b = T(9 * π),
-    g = T(9.8),
-    dt = T(0.2),
-    max_steps = 200,
-    rng = Random.GLOBAL_RNG,
-    book_or_nips = "book",
-    avail_torque = [T(-1.0), T(0.0), T(1.0)],
+    T=Float64,
+    link_length_a=T(1.0),
+    link_length_b=T(1.0),
+    link_mass_a=T(1.0),
+    link_mass_b=T(1.0),
+    link_com_pos_a=T(0.5),
+    link_com_pos_b=T(0.5),
+    link_moi=T(1.0),
+    max_torque_noise=T(0.0),
+    max_vel_a=T(4 * π),
+    max_vel_b=T(9 * π),
+    g=T(9.8),
+    dt=T(0.2),
+    max_steps=200,
+    rng=Random.GLOBAL_RNG,
+    book_or_nips="book",
+    avail_torque=[T(-1.0), T(0.0), T(1.0)],
 )

     params = AcrobotEnvParams{T}(
@@ -124,7 +81,7 @@ RLBase.is_terminated(env::AcrobotEnv) = env.done
 RLBase.state(env::AcrobotEnv) = acrobot_observation(env.state)
 RLBase.reward(env::AcrobotEnv) = env.reward

-function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}
+function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number}
     env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05)
     env.t = 0
     env.action = 2
@@ -133,7 +90,7 @@ function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}
 end

 # governing equations as per python gym
-function (env::AcrobotEnv{T})(a) where {T<:Number}
+function (env::AcrobotEnv{T})(a) where {T <: Number}
     env.action = a
     env.t += 1
     torque = env.avail_torque[a]
@@ -178,7 +135,7 @@ function dsdt(du, s_augmented, env::AcrobotEnv, t)

     # extract action and state
     a = s_augmented[end]
-    s = s_augmented[1:end-1]
+    s = s_augmented[1:end - 1]

     # writing in standard form
     theta1 = s[1]
@@ -242,7 +199,7 @@ function wrap(x, m, M)
     while x < m
         x = x + diff
     end
-    return x
+    return x
 end

 function bound(x, m, M)

src/environments/3rd_party/structs.jl

Lines changed: 44 additions & 1 deletion
@@ -6,7 +6,7 @@ struct GymEnv{T,Ta,To,P} <: AbstractEnv
 end
 export GymEnv

-mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S <: AbstractRNG} <: AbstractEnv
     ale::Ptr{Nothing}
     name::String
     screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling
@@ -38,3 +38,46 @@ mutable struct SnakeGameEnv{A,N,G} <: AbstractEnv
     is_terminated::Bool
 end
 export SnakeGameEnv
+
+struct AcrobotEnvParams{T}
+    link_length_a::T # [m]
+    link_length_b::T # [m]
+    link_mass_a::T # : [kg] mass of link 1
+    link_mass_b::T # : [kg] mass of link 2
+    # : [m] position of the center of mass of link 1
+    link_com_pos_a::T
+    # : [m] position of the center of mass of link 2
+    link_com_pos_b::T
+    # : Rotation related parameters
+    link_moi::T
+    max_torque_noise::T
+    # : [m/s] maximum velocity of link 1
+    max_vel_a::T
+    # : [m/s] maximum velocity of link 2
+    max_vel_b::T
+    # : [m/s2] acceleration due to gravity
+    g::T
+    # : [s] timestep
+    dt::T
+    # : maximum steps in episode
+    max_steps::Int
+end
+
+export AcrobotEnvParams
+
+mutable struct AcrobotEnv{T,R <: AbstractRNG} <: AbstractEnv
+    params::AcrobotEnvParams{T}
+    state::Vector{T}
+    action::Int
+    done::Bool
+    t::Int
+    rng::R
+    reward::T
+    # difference in second link angular acceleration equation
+    # as per python gym
+    book_or_nips::String
+    # array of available torques based on actions
+    avail_torque::Vector{T}
+end
+
+export AcrobotEnv
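Note the split this file introduces: the AcrobotEnvParams and AcrobotEnv definitions now sit alongside the other 3rd-party wrapper types, which appear to be loaded unconditionally, while everything that needs OrdinaryDiffEq stays behind the @require hook. The type therefore exists even without the solver; an illustration of my reading of the resulting behavior (assumed, not shown in this diff):

using ReinforcementLearningEnvironments
# without `using OrdinaryDiffEq` in this session:

AcrobotEnv isa UnionAll   # true: the parametric struct is defined eagerly
# AcrobotEnv()            # would throw a MethodError here; the keyword
#                         # constructor lives in AcrobotEnv.jl, which is
#                         # included only by the @require hook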

src/environments/examples/examples.jl

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@ include("TicTacToeEnv.jl")
 include("TinyHanabiEnv.jl")
 include("PigEnv.jl")
 include("KuhnPokerEnv.jl")
-include("AcrobotEnv.jl")
 include("CartPoleEnv.jl")
 include("MountainCarEnv.jl")
 include("PendulumEnv.jl")

test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 [deps]
 ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977"
 OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
+OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ using OpenSpiel
 using Random
 using StableRNGs
 using Statistics
+using OrdinaryDiffEq

 @testset "ReinforcementLearningEnvironments" begin
     include("environments/environments.jl")
