This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Simplify code structure #107

Merged
10 changes: 3 additions & 7 deletions Project.toml
@@ -1,23 +1,19 @@
name = "ReinforcementLearningBase"
uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
version = "0.8.5"
version = "0.9.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractTrees = "0.3"
CommonRLInterface = "0.2"
MacroTools = "0.5"
julia = "1.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
41 changes: 40 additions & 1 deletion README.md
@@ -2,4 +2,43 @@

[![Build Status](https://travis-ci.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl)

ReinforcementLearningBase.jl holds the common types and utility functions to be shared by other components in the ReinforcementLearning ecosystem.
ReinforcementLearningBase.jl holds the common types and utility functions to be
shared by other components in the ReinforcementLearning ecosystem.


## Examples

<table>
<th colspan="2">Traits</th><th> 1 </th><th> 2 </th><th> 3 </th><th> 4 </th><th> 5 </th><th> 6 </th><th> 7 </th><th> 8 </th><th> 9 </th><tr> <th rowspan="2"> ActionStyle </th><th> MinimalActionSet </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> FullActionSet </th><td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="3"> ChanceStyle </th><th> Stochastic </th><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> Deterministic </th><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> </tr>
<tr> <th> ExplicitStochastic </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th rowspan="2"> DefaultStateStyle </th><th> Observation </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> InformationSet </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th rowspan="2"> DynamicStyle </th><th> Simultaneous </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> Sequential </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th rowspan="2"> InformationStyle </th><th> PerfectInformation </th><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> ImperfectInformation </th><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th rowspan="2"> NumAgentStyle </th><th> MultiAgent </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> SingleAgent </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="2"> RewardStyle </th><th> TerminalReward </th><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> StepReward </th><td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="3"> StateStyle </th><th> Observation </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> InformationSet </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th> InternalState </th><td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="4"> UtilityStyle </th><th> GeneralSum </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> ZeroSum </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> ✔ </td></tr>
<tr> <th> ConstantSum </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> IdenticalUtility </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> </tr>
</table>
<ol><li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/MultiArmBanditsEnv.jl"> MultiArmBanditsEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/RandomWalk1D.jl"> RandomWalk1D </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TigerProblemEnv.jl"> TigerProblemEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/MontyHallEnv.jl"> MontyHallEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/RockPaperScissorsEnv.jl"> RockPaperScissorsEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TicTacToeEnv.jl"> TicTacToeEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TinyHanabiEnv.jl"> TinyHanabiEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/PigEnv.jl"> PigEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/KuhnPokerEnv.jl"> KuhnPokerEnv </a></li>
</ol>
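
The trait table above can be reproduced programmatically: each trait is an ordinary function that can be called on an environment instance. The following minimal sketch (not part of this PR) queries every trait for one of the bundled examples; it assumes `TicTacToeEnv` and the trait functions are exported into scope as the table suggests.

```julia
using ReinforcementLearningBase

# Assumed available: the bundled TicTacToeEnv example (column 6 of the table).
env = TicTacToeEnv()

# Each trait function returns the value shown in the corresponding table row.
for trait in (NumAgentStyle, DynamicStyle, ActionStyle, InformationStyle,
              ChanceStyle, RewardStyle, UtilityStyle, StateStyle, DefaultStateStyle)
    println(rpad(string(trait), 18), " => ", trait(env))
end
```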
32 changes: 16 additions & 16 deletions src/CommonRLInterface.jl
@@ -21,9 +21,9 @@ end
const CommonRLEnvs = Union{CommonRLEnv,CommonRLMarkovEnv,CommonRLZeroSumEnv}

function Base.convert(::Type{CRL.AbstractEnv}, env::AbstractEnv)
if get_num_players(env) == 1
if NumAgentStyle(env) === SINGLE_AGENT
convert(CRL.AbstractMarkovEnv, env)
elseif get_num_players(env) == 2 && UtilityStyle(env) === ZERO_SUM
elseif NumAgentStyle(env) isa MultiAgent{2} && UtilityStyle(env) === ZERO_SUM
convert(CRL.AbstractZeroSumEnv, env)
else
CommonRLEnv(env)
@@ -34,25 +34,25 @@ Base.convert(::Type{CRL.AbstractMarkovEnv}, env::AbstractEnv) = CommonRLMarkovEnv(env)
Base.convert(::Type{CRL.AbstractZeroSumEnv}, env::AbstractEnv) = CommonRLZeroSumEnv(env)

CRL.@provide CRL.reset!(env::CommonRLEnvs) = reset!(env.env)
CRL.@provide CRL.actions(env::CommonRLEnvs) = get_actions(env.env)
CRL.@provide CRL.observe(env::CommonRLEnvs) = get_state(env.env)
CRL.state(env::CommonRLEnvs) = get_state(env.env)
CRL.@provide CRL.actions(env::CommonRLEnvs) = action_space(env.env)
CRL.@provide CRL.observe(env::CommonRLEnvs) = state(env.env)
CRL.state(env::CommonRLEnvs) = state(env.env)
CRL.provided(::typeof(CRL.state), env::CommonRLEnvs) =
InformationStyle(env.env) === PERFECT_INFORMATION
CRL.@provide CRL.terminated(env::CommonRLEnvs) = get_terminal(env.env)
CRL.@provide CRL.player(env::CommonRLEnvs) = get_current_player(env.env)
CRL.@provide CRL.terminated(env::CommonRLEnvs) = is_terminated(env.env)
CRL.@provide CRL.player(env::CommonRLEnvs) = current_player(env.env)
CRL.@provide CRL.clone(env::CommonRLEnvs) = CommonRLEnv(copy(env.env))

CRL.@provide function CRL.act!(env::CommonRLEnvs, a)
env.env(a)
get_reward(env.env)
reward(env.env)
end

CRL.valid_actions(x::CommonRLEnvs) = get_legal_actions(x.env)
CRL.valid_actions(x::CommonRLEnvs) = legal_action_space(x.env)
CRL.provided(::typeof(CRL.valid_actions), env::CommonRLEnvs) =
ActionStyle(env.env) === FullActionSet()

CRL.valid_action_mask(x::CommonRLEnvs) = get_legal_actions_mask(x.env)
CRL.valid_action_mask(x::CommonRLEnvs) = legal_action_space_mask(x.env)
CRL.provided(::typeof(CRL.valid_action_mask), env::CommonRLEnvs) =
ActionStyle(env.env) === FullActionSet()

@@ -68,12 +68,12 @@ end
Base.convert(::Type{AbstractEnv}, env::CRL.AbstractEnv) = convert(RLBaseEnv, env)
Base.convert(::Type{RLBaseEnv}, env::CRL.AbstractEnv) = RLBaseEnv(env, 0.0f0) # cannot determine the reward ahead of time; assume `Float32`.

get_state(env::RLBaseEnv) = CRL.observe(env.env)
get_actions(env::RLBaseEnv) = CRL.actions(env.env)
get_reward(env::RLBaseEnv) = env.r
get_terminal(env::RLBaseEnv) = CRL.terminated(env.env)
get_legal_actions(env::RLBaseEnv) = CRL.valid_actions(env.env)
get_legal_actions_mask(env::RLBaseEnv) = CRL.valid_action_mask(env.env)
state(env::RLBaseEnv) = CRL.observe(env.env)
action_space(env::RLBaseEnv) = CRL.actions(env.env)
reward(env::RLBaseEnv) = env.r
is_terminated(env::RLBaseEnv) = CRL.terminated(env.env)
legal_action_space(env::RLBaseEnv) = CRL.valid_actions(env.env)
legal_action_space_mask(env::RLBaseEnv) = CRL.valid_action_mask(env.env)
reset!(env::RLBaseEnv) = CRL.reset!(env.env)

(env::RLBaseEnv)(a) = env.r = CRL.act!(env.env, a)
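
The renamed wrapper functions above are exercised when an RLBase environment is converted to a CommonRLInterface environment. A minimal sketch (not part of this PR), assuming the bundled `MontyHallEnv` example (single-agent per the table above) is in scope:

```julia
using ReinforcementLearningBase
import CommonRLInterface
const CRL = CommonRLInterface

env = MontyHallEnv()                      # assumed single-agent example env

# Single-agent environments take the `NumAgentStyle(env) === SINGLE_AGENT`
# branch shown above and are wrapped as a CRL.AbstractMarkovEnv.
crl_env = convert(CRL.AbstractEnv, env)

CRL.reset!(crl_env)
a = rand(CRL.actions(crl_env))            # pick an action (assumes the space supports `rand`)
r = CRL.act!(crl_env, a)                  # returns `reward` of the wrapped env after the step
done = CRL.terminated(crl_env)
```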
6 changes: 4 additions & 2 deletions src/ReinforcementLearningBase.jl
@@ -3,10 +3,12 @@ module ReinforcementLearningBase
const RLBase = ReinforcementLearningBase
export RLBase

using Random

include("inline_export.jl")
include("interface.jl")
include("implementations/implementations.jl")
include("base.jl")
include("CommonRLInterface.jl")
include("base.jl")
include("examples/examples.jl")

end # module
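
The broader point of the PR is visible in the renamed interface functions (`action_space`, `state`, `reward`, `is_terminated`, `legal_action_space`, `reset!` replacing the old `get_*` accessors). The sketch below shows how a custom environment might be written against the renamed interface; the environment itself is hypothetical and not part of this PR.

```julia
using ReinforcementLearningBase

# Hypothetical toy environment: guess a hidden number in 1:10.
mutable struct GuessNumberEnv <: AbstractEnv
    target::Int
    guess::Int
end
GuessNumberEnv() = GuessNumberEnv(rand(1:10), 0)

# Methods use the renamed interface (no `get_` prefixes).
RLBase.action_space(::GuessNumberEnv) = 1:10
RLBase.state(env::GuessNumberEnv) = env.guess
RLBase.reward(env::GuessNumberEnv) = env.guess == env.target ? 1.0 : 0.0
RLBase.is_terminated(env::GuessNumberEnv) = env.guess == env.target
RLBase.reset!(env::GuessNumberEnv) = (env.target = rand(1:10); env.guess = 0; nothing)

# Acting on the environment is a plain function call, as in CommonRLInterface.jl above.
(env::GuessNumberEnv)(a) = (env.guess = a; nothing)
```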