Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 0413be4

Browse files
Format .jl files (#104)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 56e0551 commit 0413be4

29 files changed

+238
-203
lines changed

src/base.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using MacroTools: @forward
66

77
using IntervalSets
88

9-
Random.rand(s::Union{Interval, Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)
9+
Random.rand(s::Union{Interval,Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)
1010

1111
function Random.rand(rng::AbstractRNG, s::Interval)
1212
rand(rng) * (s.right - s.left) + s.left
@@ -26,7 +26,7 @@ struct WorldSpace{T} end
2626

2727
WorldSpace() = WorldSpace{Any}()
2828

29-
Base.in(x, ::WorldSpace{T}) where T = x isa T
29+
Base.in(x, ::WorldSpace{T}) where {T} = x isa T
3030

3131
#####
3232
# ZeroTo
@@ -39,16 +39,16 @@ Similar to `Base.OneTo`. Useful when wrapping third-party environments.
3939
"""
4040
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
4141
stop::T
42-
ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T)-one(T),n))
42+
ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T) - one(T), n))
4343
end
4444

4545
ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)
4646

4747
Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
48-
Base.length(r::ZeroTo{T}) where T = T(r.stop + one(r.stop))
49-
Base.first(r::ZeroTo{T}) where T = zero(r.stop)
48+
Base.length(r::ZeroTo{T}) where {T} = T(r.stop + one(r.stop))
49+
Base.first(r::ZeroTo{T}) where {T} = zero(r.stop)
5050

51-
function getindex(v::ZeroTo{T}, i::Integer) where T
51+
function getindex(v::ZeroTo{T}, i::Integer) where {T}
5252
Base.@_inline_meta
5353
@boundscheck ((i >= 0) & (i <= v.stop)) || throw_boundserror(v, i)
5454
convert(T, i)
@@ -76,15 +76,16 @@ Base.similar(s::Space, args...) = Space(similar(s.s, args...))
7676

7777
Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)
7878

79-
Random.rand(rng::AbstractRNG, s::Space) = map(s.s) do x
80-
rand(rng, x)
81-
end
79+
Random.rand(rng::AbstractRNG, s::Space) =
80+
map(s.s) do x
81+
rand(rng, x)
82+
end
8283

83-
Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k=>rand(rng,v) for (k,v) in s.s)
84+
Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k => rand(rng, v) for (k, v) in s.s)
8485

8586
function Base.in(X, S::Space)
8687
if length(X) == length(S.s)
87-
for (x,s) in zip(X, S.s)
88+
for (x, s) in zip(X, S.s)
8889
if x ∉ s
8990
return false
9091
end

src/converters.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/environments/3rd_party/atari.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ function AtariEnv(;
5757
observation_size =
5858
grayscale_obs ? (getScreenWidth(ale), getScreenHeight(ale)) :
5959
(3, getScreenWidth(ale), getScreenHeight(ale)) # !!! note the order
60-
observation_space = Space(
61-
ClosedInterval{Cuchar}.(
62-
fill(typemin(Cuchar), observation_size),
63-
fill(typemax(Cuchar), observation_size),
64-
)
65-
)
60+
observation_space = Space(ClosedInterval{
61+
Cuchar,
62+
}.(
63+
fill(typemin(Cuchar), observation_size),
64+
fill(typemax(Cuchar), observation_size),
65+
))
6666

6767
actions = full_action_space ? getLegalActionSet(ale) : getMinimalActionSet(ale)
6868
action_space = Base.OneTo(length(actions))
@@ -165,16 +165,14 @@ end
165165

166166
function Base.show(io::IO, m::MIME"image/png", env::AtariEnv)
167167
x = getScreenRGB(env.ale)
168-
p=imshowcolor(x, (Int(getScreenWidth(env.ale)), Int(getScreenHeight(env.ale))))
168+
p = imshowcolor(x, (Int(getScreenWidth(env.ale)), Int(getScreenHeight(env.ale))))
169169
show(io, m, p)
170170
end
171171

172172
Base.show(io::IO, t::MIME"text/plain", env::AbstractEnv) = show(
173-
IOContext(
174-
io,
175-
:is_show_state => false,
176-
:is_show_state_space => false),
173+
IOContext(io, :is_show_state => false, :is_show_state_space => false),
177174
MIME"text/markdown"(),
178-
env)
175+
env,
176+
)
179177

180178
list_atari_rom_names() = getROMList()

src/environments/3rd_party/gym.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ Random.seed!(env::GymEnv, s) = env.pyenv.seed(s)
9696
function space_transform(s::PyObject)
9797
spacetype = s.__class__.__name__
9898
if spacetype == "Box"
99-
Space(ClosedInterval.(s.low,s.high))
99+
Space(ClosedInterval.(s.low, s.high))
100100
elseif spacetype == "Discrete" # for GymEnv("CliffWalking-v0"), `s.n` is of type PyObject (numpy.int64)
101101
ZeroTo(py"int($s.n)" - 1)
102102
elseif spacetype == "MultiBinary"

src/environments/3rd_party/open_spiel.jl

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,7 @@ using StatsBase: sample, weights
4444
`True` or `False` (instead of `true` or `false`). Another approach is to just
4545
specify parameters in `kwargs` in the Julia style.
4646
"""
47-
function OpenSpielEnv(
48-
name="kuhn_poker";
49-
kwargs...,
50-
)
47+
function OpenSpielEnv(name = "kuhn_poker"; kwargs...)
5148
game = load_game(String(name); kwargs...)
5249
state = new_initial_state(game)
5350
OpenSpielEnv(state, game)
@@ -65,21 +62,21 @@ RLBase.players(env::OpenSpielEnv) = 0:(num_players(env.game)-1)
6562

6663
function RLBase.action_space(env::OpenSpielEnv, player)
6764
if player == chance_player(env)
68-
[k for (k,v) in chance_outcomes(env.state)]
65+
[k for (k, v) in chance_outcomes(env.state)]
6966
else
70-
ZeroTo(num_distinct_actions(env.game)-1)
67+
ZeroTo(num_distinct_actions(env.game) - 1)
7168
end
7269
end
7370

7471
function RLBase.legal_action_space(env::OpenSpielEnv, player)
7572
if player == chance_player(env)
76-
[k for (k,v) in chance_outcomes(env.state)]
73+
[k for (k, v) in chance_outcomes(env.state)]
7774
else
7875
legal_actions(env.state, player)
7976
end
8077
end
8178

82-
RLBase.prob(env::OpenSpielEnv, player) = [v for (k,v) in chance_outcomes(env.state)]
79+
RLBase.prob(env::OpenSpielEnv, player) = [v for (k, v) in chance_outcomes(env.state)]
8380

8481
function RLBase.legal_action_space_mask(env::OpenSpielEnv, player)
8582
n =
@@ -107,7 +104,7 @@ end
107104

108105
function RLBase.state(env::OpenSpielEnv, ss::RLBase.AbstractStateStyle, player)
109106
if player < 0 # TODO: revisit this in [email protected]
110-
@warn "unexpected player $player, falling back to default state value." maxlog=1
107+
@warn "unexpected player $player, falling back to default state value." maxlog = 1
111108
s = state_space(env)
112109
if s isa WorldSpace
113110
""
@@ -119,13 +116,28 @@ function RLBase.state(env::OpenSpielEnv, ss::RLBase.AbstractStateStyle, player)
119116
end
120117
end
121118

122-
_state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) = information_state_string(env.state, player)
123-
_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = information_state_tensor(env.state, player)
124-
_state(env::OpenSpielEnv, ::Observation{String}, player) = observation_string(env.state, player)
125-
_state(env::OpenSpielEnv, ::Observation{Array}, player) = observation_tensor(env.state, player)
126-
127-
RLBase.state_space(env::OpenSpielEnv, ::Union{InformationSet{String},Observation{String}}, p) = WorldSpace{AbstractString}()
128-
RLBase.state_space(env::OpenSpielEnv, ::Union{InformationSet{Array},Observation{Array}}, p) = Space(fill(typemin(Float64)..typemax(Float64), information_state_tensor_size(env.state)))
119+
_state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) =
120+
information_state_string(env.state, player)
121+
_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) =
122+
information_state_tensor(env.state, player)
123+
_state(env::OpenSpielEnv, ::Observation{String}, player) =
124+
observation_string(env.state, player)
125+
_state(env::OpenSpielEnv, ::Observation{Array}, player) =
126+
observation_tensor(env.state, player)
127+
128+
RLBase.state_space(
129+
env::OpenSpielEnv,
130+
::Union{InformationSet{String},Observation{String}},
131+
p,
132+
) = WorldSpace{AbstractString}()
133+
RLBase.state_space(
134+
env::OpenSpielEnv,
135+
::Union{InformationSet{Array},Observation{Array}},
136+
p,
137+
) = Space(fill(
138+
typemin(Float64)..typemax(Float64),
139+
information_state_tensor_size(env.state),
140+
))
129141

130142
Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently."
131143

@@ -154,10 +166,16 @@ function RLBase.UtilityStyle(env::OpenSpielEnv)
154166
end
155167

156168
RLBase.ActionStyle(env::OpenSpielEnv) = FULL_ACTION_SET
157-
RLBase.DynamicStyle(env::OpenSpielEnv) = dynamics(get_type(env.game))== OpenSpiel.SEQUENTIAL ? RLBase.SEQUENTIAL : RLBase.SIMULTANEOUS
158-
RLBase.InformationStyle(env::OpenSpielEnv) = information(get_type(env.game)) ==OpenSpiel.PERFECT_INFORMATION ? RLBase.PERFECT_INFORMATION : RLBase.IMPERFECT_INFORMATION
169+
RLBase.DynamicStyle(env::OpenSpielEnv) =
170+
dynamics(get_type(env.game)) == OpenSpiel.SEQUENTIAL ? RLBase.SEQUENTIAL :
171+
RLBase.SIMULTANEOUS
172+
RLBase.InformationStyle(env::OpenSpielEnv) =
173+
information(get_type(env.game)) == OpenSpiel.PERFECT_INFORMATION ?
174+
RLBase.PERFECT_INFORMATION : RLBase.IMPERFECT_INFORMATION
159175
RLBase.NumAgentStyle(env::OpenSpielEnv) = MultiAgent(num_players(env.game))
160-
RLBase.RewardStyle(env::OpenSpielEnv) = reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? RLBase.STEP_REWARD : RLBase.TERMINAL_REWARD
176+
RLBase.RewardStyle(env::OpenSpielEnv) =
177+
reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? RLBase.STEP_REWARD :
178+
RLBase.TERMINAL_REWARD
161179

162180
RLBase.StateStyle(env::OpenSpielEnv) = (
163181
RLBase.InformationSet{String}(),

src/environments/environments.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
include("examples/examples.jl")
22
include("non_interactive/non_interactive.jl")
33
include("wrappers/wrappers.jl")
4-
include("3rd_party/structs.jl")
4+
include("3rd_party/structs.jl")

src/environments/examples/AcrobotEnv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ acrobot_observation(s) = [cos(s[1]), sin(s[1]), cos(s[2]), sin(s[2]), s[3], s[4]
115115

116116
RLBase.action_space(env::AcrobotEnv) = Base.OneTo(3)
117117

118-
function RLBase.state_space(env::AcrobotEnv{T}) where T
118+
function RLBase.state_space(env::AcrobotEnv{T}) where {T}
119119
high = [1.0, 1.0, 1.0, 1.0, env.params.max_vel_a, env.params.max_vel_b]
120120
Space(ClosedInterval{T}.(-high, high))
121121
end

src/environments/examples/CartPoleEnv.jl

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,7 @@ function CartPoleEnv(;
6666
2.4,
6767
max_steps,
6868
)
69-
high =
70-
cp = CartPoleEnv(
71-
params,
72-
zeros(T, 4),
73-
2,
74-
false,
75-
0,
76-
rng,
77-
)
69+
high = cp = CartPoleEnv(params, zeros(T, 4), 2, false, 0, rng)
7870
reset!(cp)
7971
cp
8072
end
@@ -91,14 +83,12 @@ end
9183

9284
RLBase.action_space(env::CartPoleEnv) = Base.OneTo(2)
9385

94-
RLBase.state_space(env::CartPoleEnv{T}) where T = Space(
95-
ClosedInterval{T}[
96-
(-2 * env.params.xthreshold)..(2 * env.params.xthreshold),
97-
-1e38..1e38,
98-
(-2 * env.params.thetathreshold)..(2 * env.params.thetathreshold),
99-
-1e38..1e38
100-
]
101-
)
86+
RLBase.state_space(env::CartPoleEnv{T}) where {T} = Space(ClosedInterval{T}[
87+
(-2*env.params.xthreshold)..(2*env.params.xthreshold),
88+
-1e38..1e38,
89+
(-2*env.params.thetathreshold)..(2*env.params.thetathreshold),
90+
-1e38..1e38,
91+
])
10292

10393
RLBase.reward(env::CartPoleEnv{T}) where {T} = env.done ? zero(T) : one(T)
10494
RLBase.is_terminated(env::CartPoleEnv) = env.done

src/environments/examples/KuhnPokerEnv.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ RLBase.state_space(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, p
103103
KUHN_POKER_STATES
104104

105105
RLBase.action_space(env::KuhnPokerEnv, ::Int) = Base.OneTo(length(KUHN_POKER_ACTIONS))
106-
RLBase.action_space(env::KuhnPokerEnv, ::ChancePlayer) = Base.OneTo(length(KUHN_POKER_CARDS))
106+
RLBase.action_space(env::KuhnPokerEnv, ::ChancePlayer) =
107+
Base.OneTo(length(KUHN_POKER_CARDS))
107108

108109
RLBase.legal_action_space(env::KuhnPokerEnv, p::ChancePlayer) =
109110
[x for x in action_space(env, p) if KUHN_POKER_CARDS[x] ∉ env.cards]

src/environments/examples/TicTacToeEnv.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ RLBase.current_player(env::TicTacToeEnv) = env.player
7575
RLBase.players(env::TicTacToeEnv) = (CROSS, NOUGHT)
7676

7777
RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
78-
RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = Space(fill(false..true, 3, 3, 3))
79-
RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) = get_tic_tac_toe_state_info()[env].index
78+
RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) =
79+
Space(fill(false..true, 3, 3, 3))
80+
RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
81+
get_tic_tac_toe_state_info()[env].index
8082
RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
8183
Base.OneTo(length(get_tic_tac_toe_state_info()))
8284

src/environments/examples/TinyHanabiEnv.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ RLBase.action_space(env::TinyHanabiEnv, ::Int) = Base.OneTo(3)
5656
RLBase.action_space(env::TinyHanabiEnv, ::ChancePlayer) = Base.OneTo(2)
5757

5858
RLBase.legal_action_space(env::TinyHanabiEnv, ::ChancePlayer) = findall(!in(env.cards), 1:2)
59-
RLBase.legal_action_space_mask(env::TinyHanabiEnv, ::ChancePlayer) = [x ∉ env.cards for x in 1:2]
59+
RLBase.legal_action_space_mask(env::TinyHanabiEnv, ::ChancePlayer) =
60+
[x ∉ env.cards for x in 1:2]
6061

6162
function RLBase.prob(env::TinyHanabiEnv, ::ChancePlayer)
6263
if isempty(env.cards)

src/environments/non_interactive/pendulum.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ Random.seed!(env::PendulumNonInteractiveEnv, seed) = Random.seed!(env.rng, seed)
6969
RLBase.reward(env::PendulumNonInteractiveEnv) = 0
7070
RLBase.is_terminated(env::PendulumNonInteractiveEnv) = env.done
7171
RLBase.state(env::PendulumNonInteractiveEnv) = env.state
72-
RLBase.state_space(env::PendulumNonInteractiveEnv{T}) where T = Space([typemin(T)..typemax(T), typemin(T)..typemax(T)])
72+
RLBase.state_space(env::PendulumNonInteractiveEnv{T}) where {T} =
73+
Space([typemin(T)..typemax(T), typemin(T)..typemax(T)])
7374

7475
function RLBase.reset!(env::PendulumNonInteractiveEnv{Fl}) where {Fl}
7576
env.state .= (Fl(2 * pi) * rand(env.rng, Fl), randn(env.rng, Fl))

src/environments/wrappers/ActionTransformedEnv.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ end
1313
`legal_action_space(env)`. `action_mapping` will be applied to `action` before
1414
feeding it into `env`.
1515
"""
16-
function ActionTransformedEnv(env; action_space_mapping=identity, action_mapping=identity)
16+
function ActionTransformedEnv(
17+
env;
18+
action_space_mapping = identity,
19+
action_mapping = identity,
20+
)
1721
ActionTransformedEnv(action_space_mapping, action_mapping, env)
1822
end
1923

@@ -25,9 +29,13 @@ for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
2529
end
2630

2731
RLBase.state(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
28-
RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss)
32+
RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) =
33+
state_space(env.env, ss)
2934

30-
RLBase.action_space(env::ActionTransformedEnv) = env.action_space_mapping(action_space(env.env))
31-
RLBase.legal_action_space(env::ActionTransformedEnv) = env.action_space_mapping(legal_action_space(env.env))
35+
RLBase.action_space(env::ActionTransformedEnv) =
36+
env.action_space_mapping(action_space(env.env))
37+
RLBase.legal_action_space(env::ActionTransformedEnv) =
38+
env.action_space_mapping(legal_action_space(env.env))
3239

33-
(env::ActionTransformedEnv)(action, args...; kwargs...) = env.env(env.action_mapping(action), args...; kwargs...)
40+
(env::ActionTransformedEnv)(action, args...; kwargs...) =
41+
env.env(env.action_mapping(action), args...; kwargs...)

src/environments/wrappers/DefaultStateStyle.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
2020
end
2121
end
2222

23-
(env::DefaultStateStyleEnv)(args...;kwargs...) = env.env(args...;kwargs...)
23+
(env::DefaultStateStyleEnv)(args...; kwargs...) = env.env(args...; kwargs...)
2424

2525
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
26-
RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss)
26+
RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) =
27+
state_space(env.env, ss)

src/environments/wrappers/MaxTimeoutEnv.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ end
2121

2222
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
2323
if f != :terminal
24-
@eval RLBase.$f(x::MaxTimeoutEnv, args...; kwargs...) = $f(x.env, args...; kwargs...)
24+
@eval RLBase.$f(x::MaxTimeoutEnv, args...; kwargs...) =
25+
$f(x.env, args...; kwargs...)
2526
end
2627
end
2728

28-
RLBase.is_terminated(env::MaxTimeoutEnv) = (env.current_t > env.max_t) || is_terminated(env.env)
29+
RLBase.is_terminated(env::MaxTimeoutEnv) =
30+
(env.current_t > env.max_t) || is_terminated(env.env)
2931

3032

3133
RLBase.state(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
32-
RLBase.state_space(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss)
34+
RLBase.state_space(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) =
35+
state_space(env.env, ss)

src/environments/wrappers/MultiThreadEnv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ end
122122

123123
function RLBase.legal_action_space_mask(env::MultiThreadEnv)
124124
@sync for i in 1:length(env)
125-
@spawn selectdim(env.legal_action_space_mask, N, i) .= legal_action_space_mask(env[i])
125+
@spawn selectdim(env.legal_action_space_mask, N, i) .=
126+
legal_action_space_mask(env[i])
126127
end
127128
env.legal_action_space_mask
128129
end
@@ -136,4 +137,4 @@ for f in RLBase.ENV_API
136137
if endswith(String(f), "Style")
137138
@eval RLBase.$f(x::MultiThreadEnv) = $f(x[1])
138139
end
139-
end
140+
end

0 commit comments

Comments
 (0)