Skip to content
This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Automatic JuliaFormatter.jl run #111

Merged
merged 1 commit into from
Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function test_interfaces!(env)
@testset "SingleAgent" begin
if NumAgentStyle(env) === SINGLE_AGENT
reset!(env)
total_reward = 0.
total_reward = 0.0
while !is_terminated(env)
if StateStyle(env) isa Tuple
for ss in StateStyle(env)
Expand Down Expand Up @@ -272,7 +272,7 @@ end

using IntervalSets

Random.rand(s::Union{Interval, Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)
Random.rand(s::Union{Interval,Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)

function Random.rand(rng::AbstractRNG, s::Interval)
rand(rng) * (s.right - s.left) + s.left
Expand All @@ -292,7 +292,7 @@ struct WorldSpace{T} end

WorldSpace() = WorldSpace{Any}()

Base.in(x, ::WorldSpace{T}) where T = x isa T
Base.in(x, ::WorldSpace{T}) where {T} = x isa T

#####
# ZeroTo
Expand All @@ -305,16 +305,16 @@ Similar to `Base.OneTo`. Useful when wrapping third-party environments.
"""
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
stop::T
ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T)-one(T),n))
ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T) - one(T), n))
end

ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)

Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
Base.length(r::ZeroTo{T}) where T = T(r.stop + one(r.stop))
Base.first(r::ZeroTo{T}) where T = zero(r.stop)
Base.length(r::ZeroTo{T}) where {T} = T(r.stop + one(r.stop))
Base.first(r::ZeroTo{T}) where {T} = zero(r.stop)

function getindex(v::ZeroTo{T}, i::Integer) where T
function getindex(v::ZeroTo{T}, i::Integer) where {T}
Base.@_inline_meta
@boundscheck ((i >= 0) & (i <= v.stop)) || throw_boundserror(v, i)
convert(T, i)
Expand Down Expand Up @@ -349,7 +349,8 @@ function weighted_sample(rng::AbstractRNG, wv)
end
end

Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) = s[weighted_sample(rng, (x.prob for x in s))]
Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) =
s[weighted_sample(rng, (x.prob for x in s))]

(env::AbstractEnv)(a::ActionProbPair) = env(a.action)

Expand All @@ -368,15 +369,16 @@ end

Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)

Random.rand(rng::AbstractRNG, s::Space) = map(s.s) do x
rand(rng, x)
end
Random.rand(rng::AbstractRNG, s::Space) =
map(s.s) do x
rand(rng, x)
end

Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k=>rand(rng,v) for (k,v) in s.s)
Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k => rand(rng, v) for (k, v) in s.s)

function Base.in(X, S::Space)
if length(X) == length(S.s)
for (x,s) in zip(X, S.s)
for (x, s) in zip(X, S.s)
if x ∉ s
return false
end
Expand All @@ -398,4 +400,4 @@ function Base.in(X::Dict, S::Space{<:Dict})
else
return false
end
end
end
3 changes: 2 additions & 1 deletion src/examples/TicTacToeEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ current_player(env::TicTacToeEnv) = env.player
players(env::TicTacToeEnv) = (CROSS, NOUGHT)

state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = Space(fill(false..true, 3, 3, 3))
state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) =
Space(fill(false..true, 3, 3, 3))
state(env::TicTacToeEnv, ::Observation{Int}, p) = get_tic_tac_toe_state_info()[env].index
state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
Base.OneTo(length(get_tic_tac_toe_state_info()))
Expand Down
2 changes: 1 addition & 1 deletion test/examples/kuhn_poker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

env = KuhnPokerEnv()

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

end
2 changes: 1 addition & 1 deletion test/examples/monty_hall_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
rng = StableRNG(123)
env = MontyHallEnv(; rng = rng)

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

n_win_car = 0
N = 50_000
Expand Down
2 changes: 1 addition & 1 deletion test/examples/multi_arm_bandits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
env = MultiArmBanditsEnv(; rng = rng)
rewards = []

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

N = 50_000
for _ in 1:N
Expand Down
2 changes: 1 addition & 1 deletion test/examples/random_walk_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
end_rewards = 3 => 5
env = RandomWalk1D(; rewards = end_rewards)

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

rng = StableRNG(123)
N = 50_000
Expand Down
2 changes: 1 addition & 1 deletion test/examples/rock_paper_scissors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
rng = StableRNG(123)
env = RockPaperScissorsEnv()

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

rewards = [[], []]
for _ in 1:50_000
Expand Down
2 changes: 1 addition & 1 deletion test/examples/tic_tac_toe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

env = TicTacToeEnv()

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

@test length(state_space(env, Observation{Int}())) == 5478

Expand Down
2 changes: 1 addition & 1 deletion test/examples/tiger_problem_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
obs_prob = 0.85
env = TigerProblemEnv(; rng = rng, obs_prob = obs_prob)

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

rewards = []
for _ in 1:50_000
Expand Down
2 changes: 1 addition & 1 deletion test/examples/tiny_hanabi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

env = TinyHanabiEnv()

RLBase.test_interfaces!(env)
RLBase.test_interfaces!(env)

end