This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Simplify code structure #112

Merged: 4 commits, Dec 18, 2020
4 changes: 0 additions & 4 deletions Project.toml
@@ -6,15 +6,11 @@ version = "0.9.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractTrees = "0.3"
CommonRLInterface = "0.2"
IntervalSets = "0.5"
MacroTools = "0.5"
julia = "1.3"
49 changes: 9 additions & 40 deletions README.md
@@ -2,43 +2,12 @@

[![Build Status](https://travis-ci.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl)

ReinforcementLearningBase.jl holds the common types and utility functions shared by
other components in the ReinforcementLearning ecosystem.


## Examples

<table>
<th colspan="2">Traits</th><th> 1 </th><th> 2 </th><th> 3 </th><th> 4 </th><th> 5 </th><th> 6 </th><th> 7 </th><th> 8 </th><th> 9 </th><tr> <th rowspan="2"> ActionStyle </th><th> MinimalActionSet </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> FullActionSet </th><td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="3"> ChanceStyle </th><th> Stochastic </th><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> Deterministic </th><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> </tr>
<tr> <th> ExplicitStochastic </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th rowspan="2"> DefaultStateStyle </th><th> Observation </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> InformationSet </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th rowspan="2"> DynamicStyle </th><th> Simultaneous </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> Sequential </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th rowspan="2"> InformationStyle </th><th> PerfectInformation </th><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> ImperfectInformation </th><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th rowspan="2"> NumAgentStyle </th><th> MultiAgent </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> SingleAgent </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="2"> RewardStyle </th><th> TerminalReward </th><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> StepReward </th><td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="3"> StateStyle </th><th> Observation </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> InformationSet </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th> InternalState </th><td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="4"> UtilityStyle </th><th> GeneralSum </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> ZeroSum </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> ✔ </td></tr>
<tr> <th> ConstantSum </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> IdenticalUtility </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> </tr>
</table>
<ol><li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/MultiArmBanditsEnv.jl"> MultiArmBanditsEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/RandomWalk1D.jl"> RandomWalk1D </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TigerProblemEnv.jl"> TigerProblemEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/MontyHallEnv.jl"> MontyHallEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/RockPaperScissorsEnv.jl"> RockPaperScissorsEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TicTacToeEnv.jl"> TicTacToeEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TinyHanabiEnv.jl"> TinyHanabiEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/PigEnv.jl"> PigEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/KuhnPokerEnv.jl"> KuhnPokerEnv </a></li>
</ol>
This package defines two core concepts in reinforcement learning
(a minimal usage sketch follows the list below):

- `AbstractEnv`. Check out
  [ReinforcementLearningEnvironments.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl)
  for a wide variety of ready-made environments.
- `AbstractPolicy`.
  [ReinforcementLearningCore.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningCore.jl)
  is a good starting point for writing customized policies.
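For orientation, here is a minimal, illustrative sketch of how the two abstractions fit together. It is not part of this PR: the `CoinFlipEnv` and `AlwaysHead` names are invented for the example, and it assumes that the single-argument functions used elsewhere in this diff (`action_space`, `state`, `state_space`, `is_terminated`, `reset!`, together with `reward`) are the methods to overload, and that an environment is stepped by calling it on an action, as `env(a)` does in `test_runnable!` below.

```julia
using ReinforcementLearningBase
const RLB = ReinforcementLearningBase

# A toy one-shot environment: the agent picks :head or :tail,
# receives +1.0 or -1.0, and the episode ends.
Base.@kwdef mutable struct CoinFlipEnv <: AbstractEnv
    reward::Float64 = 0.0
    done::Bool = false
end

RLB.action_space(::CoinFlipEnv) = [:head, :tail]
RLB.state(env::CoinFlipEnv) = env.done
RLB.state_space(::CoinFlipEnv) = [false, true]
RLB.reward(env::CoinFlipEnv) = env.reward
RLB.is_terminated(env::CoinFlipEnv) = env.done
RLB.reset!(env::CoinFlipEnv) = (env.reward = 0.0; env.done = false; nothing)

# Taking an action advances the environment (the functor convention also
# used by `test_runnable!` below).
function (env::CoinFlipEnv)(action)
    env.reward = action == :head ? 1.0 : -1.0
    env.done = true
end

# A policy is simply something callable on an environment that returns an action.
struct AlwaysHead <: AbstractPolicy end
(p::AlwaysHead)(env::AbstractEnv) = :head

# Roll out one episode.
env = CoinFlipEnv()
policy = AlwaysHead()
while !is_terminated(env)
    env(policy(env))
end
reward(env)   # 1.0
```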
1 change: 0 additions & 1 deletion src/ReinforcementLearningBase.jl
@@ -9,6 +9,5 @@ include("inline_export.jl")
include("interface.jl")
include("CommonRLInterface.jl")
include("base.jl")
include("examples/examples.jl")

end # module
221 changes: 15 additions & 206 deletions src/base.jl
@@ -191,213 +191,22 @@ function test_interfaces!(env)
reset!(env)
end

#####
# Generate README
#####

gen_traits_table(envs) = gen_traits_table(stdout, envs)

function gen_traits_table(io, envs)
trait_dict = Dict()
for f in env_traits()
for env in envs
if !haskey(trait_dict, f)
trait_dict[f] = Set()
end
t = f(env)
if f == StateStyle
if t isa Tuple
for x in t
push!(trait_dict[f], nameof(typeof(x)))
end
else
push!(trait_dict[f], nameof(typeof(t)))
end
else
push!(trait_dict[f], nameof(typeof(t)))
end
end
end

println(io, "<table>")

print(io, "<th colspan=\"2\">Traits</th>")
for i in 1:length(envs)
print(io, "<th> $(i) </th>")
end

for k in sort(collect(keys(trait_dict)), by = nameof)
vs = trait_dict[k]
print(io, "<tr> <th rowspan=\"$(length(vs))\"> $(nameof(k)) </th>")
for (i, v) in enumerate(vs)
if i != 1
print(io, "<tr> ")
end
print(io, "<th> $(v) </th>")
for env in envs
if k == StateStyle && k(env) isa Tuple
ss = k(env)
if v in map(x -> nameof(typeof(x)), ss)
print(io, "<td> ✔ </td>")
else
print(io, "<td> </td> ")
end
else
if nameof(typeof(k(env))) == v
print(io, "<td> ✔ </td>")
else
print(io, "<td> </td> ")
end
end
end
println(io, "</tr>")
end
end

println(io, "</table>")

print(io, "<ol>")
for env in envs
println(
io,
"<li> <a href=\"https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/$(nameof(env)).jl\"> $(nameof(env)) </a></li>",
)
end
print(io, "</ol>")
end

#####
# Utils
#####

using IntervalSets

Random.rand(s::Union{Interval,Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)

function Random.rand(rng::AbstractRNG, s::Interval)
rand(rng) * (s.right - s.left) + s.left
end

#####
# WorldSpace
#####

export WorldSpace

"""
In some cases, we may not be interested in the action/state space.
One can return `WorldSpace()` to keep the interface consistent.
"""
struct WorldSpace{T} end

WorldSpace() = WorldSpace{Any}()

Base.in(x, ::WorldSpace{T}) where {T} = x isa T
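For reference, how the definitions above behave (this utility is removed by the PR; the snippet assumes a pre-PR ReinforcementLearningBase is loaded): membership is just an `isa` check, so the default `WorldSpace()` accepts any value.

```julia
42 in WorldSpace()                  # true: WorldSpace() is WorldSpace{Any}()
"anything" in WorldSpace{Number}()  # false: a String is not a Number
```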

#####
# ZeroTo
#####

export ZeroTo

"""
Similar to `Base.OneTo`. Useful when wrapping third-party environments.
"""
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
stop::T
ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T) - one(T), n))
end

ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)

Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
Base.length(r::ZeroTo{T}) where {T} = T(r.stop + one(r.stop))
Base.first(r::ZeroTo{T}) where {T} = zero(r.stop)

function getindex(v::ZeroTo{T}, i::Integer) where {T}
Base.@_inline_meta
@boundscheck ((i >= 0) & (i <= v.stop)) || throw_boundserror(v, i)
convert(T, i)
end
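A small usage sketch of the `ZeroTo` range defined above (also removed by this PR; assumes a pre-PR ReinforcementLearningBase): it behaves like `Base.OneTo` shifted to start at zero, which matches the 0-based action and state indices of many wrapped third-party environments.

```julia
r = ZeroTo(3)
first(r)    # 0
length(r)   # 4, i.e. the values 0, 1, 2, 3
3 in r      # true
```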

#####
# ActionProbPair
#####

export ActionProbPair

"""
Used in action space of chance player.
"""
struct ActionProbPair{A,P}
action::A
prob::P
end

"""
Directly copied from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510) to avoid depending on the whole package.
Here we assume `wv` sum to `1`
"""
function weighted_sample(rng::AbstractRNG, wv)
t = rand(rng)
cw = zero(Base.first(wv))
for (i, w) in enumerate(wv)
cw += w
if cw >= t
return i
end
end
end

Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) =
s[weighted_sample(rng, (x.prob for x in s))]

(env::AbstractEnv)(a::ActionProbPair) = env(a.action)
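An illustrative use of the two definitions above (removed by this PR; assumes a pre-PR ReinforcementLearningBase): a chance player can expose its action space as weighted pairs, and `rand` then draws one pair according to the probabilities, which are assumed to sum to 1.

```julia
using Random

chance_outcomes = [ActionProbPair(:rain, 0.3), ActionProbPair(:sun, 0.7)]

rng = MersenneTwister(123)
outcome = rand(rng, chance_outcomes)  # one ActionProbPair, drawn with weights 0.3 / 0.7
# `env(outcome)` would then forward only `outcome.action` to the environment.
```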

#####
# Space
#####

export Space

"""
A wrapper to treat each element as a sub-space which supports `Random.rand` and `Base.in`.
"""
struct Space{T}
s::T
end

Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)

Random.rand(rng::AbstractRNG, s::Space) =
map(s.s) do x
rand(rng, x)
end

Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k => rand(rng, v) for (k, v) in s.s)

function Base.in(X, S::Space)
if length(X) == length(S.s)
for (x, s) in zip(X, S.s)
if x ∉ s
return false
end
end
return true
else
return false
end
end
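A short sketch of how the `Space` wrapper above behaves (removed by this PR; assumes a pre-PR ReinforcementLearningBase): `rand` samples each sub-space independently, and `in` checks every component against its sub-space.

```julia
s = Space((1:3, ["a", "b"]))

x = rand(s)     # e.g. (2, "b"): one draw per sub-space
x in s          # true
(4, "a") in s   # false, because 4 ∉ 1:3
```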

function Base.in(X::Dict, S::Space{<:Dict})
    if keys(X) == keys(S.s)
        for k in keys(X)
            if X[k] ∉ S.s[k]
                return false
            end
        end
        return true
    else
        return false
    end
end

function test_runnable!(env, n = 1000; rng = Random.GLOBAL_RNG)
    @testset "random policy with $(nameof(env))" begin
        reset!(env)
        for _ in 1:n
            A = legal_action_space(env)
            a = rand(rng, A)
            @test a in A

            S = state_space(env)
            s = state(env)
            @test s in S
            env(a)
            if is_terminated(env)
                reset!(env)
            end
        end
        reset!(env)
    end
end
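A hypothetical call to the newly added helper, reusing the invented `CoinFlipEnv` sketched after the README section above. It additionally assumes the default `MinimalActionSet` trait, under which `legal_action_space(env)` falls back to `action_space(env)`.

```julia
using Random

env = CoinFlipEnv()
test_runnable!(env, 100; rng = MersenneTwister(0))
```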