This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 0ca9249

Dev (#18)

* make DiscreteSpace more flexible
* format
* add VBasedPolicy
* add policies
* set compat of Julia to v1.3
* drop ReinforcementLearningEnvironments.jl in deps
* update compat

1 parent 67f7197, commit 0ca9249


44 files changed, +612 -510 lines changed

Project.toml

Lines changed: 6 additions & 12 deletions
@@ -4,29 +4,23 @@ authors = ["Jun Tian <[email protected]>"]
 version = "0.1.0"
 
 [deps]
-CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [compat]
-Reexport = "0.2"
-StatsBase = "0.32"
-MacroTools = "0.5"
-ProgressMeter = "1.2"
 Distributions = "0.22"
-ReinforcementLearningBase = "0.5"
-CuArrays = "1.7"
-Flux = "0.10"
-julia = "1"
+ProgressMeter = "1.2"
+ReinforcementLearningBase = "0.6"
+StatsBase = "0.32"
+julia = "1.3"
 
 [extras]
+ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Test", "ReinforcementLearningEnvironments"]

src/ReinforcementLearningCore.jl

Lines changed: 1 addition & 3 deletions
@@ -1,12 +1,10 @@
 module ReinforcementLearningCore
 
-using Reexport
+using ReinforcementLearningBase
 
 const RLCore = ReinforcementLearningCore
 export RLCore
 
-@reexport using ReinforcementLearningBase
-
 include("utils/utils.jl")
 include("core/core.jl")
 include("components/components.jl")

src/components/agents/agent.jl

Lines changed: 8 additions & 3 deletions
@@ -23,6 +23,8 @@ Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: A
     role::R = DEFAULT_PLAYER
 end
 
+RLBase.get_role(agent::Agent) = agent.role
+
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PreEpisodeStage,
     obs,
@@ -35,9 +37,9 @@ function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PreActStage,
     obs,
 )
-    update!(agent.policy, agent.trajectory)
     action = agent.policy(obs)
     push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.policy, agent.trajectory)
     action
 end
 
@@ -55,6 +57,7 @@ function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
 )
     action = agent.policy(obs)
     push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.policy, agent.trajectory)
     action
 end
 
@@ -72,9 +75,9 @@ function (agent::Agent{<:AbstractPolicy,<:CircularCompactSARTSATrajectory})(
     ::PreActStage,
     obs,
 )
-    update!(agent.policy, agent.trajectory)
     action = agent.policy(obs)
     push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.policy, agent.trajectory)
     action
 end
 
@@ -92,6 +95,7 @@ function (agent::Agent{<:AbstractPolicy,<:CircularCompactSARTSATrajectory})(
 )
     action = agent.policy(obs)
     push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.policy, agent.trajectory)
     action
 end
 
@@ -109,9 +113,9 @@ function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::PreActStage,
     obs,
 )
-    update!(agent.policy, agent.trajectory)
     action = agent.policy(obs)
     push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.policy, agent.trajectory)
     action
 end
 
@@ -129,5 +133,6 @@ function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
 )
     action = agent.policy(obs)
     push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.policy, agent.trajectory)
     action
 end
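
Two behavioural changes here: `Agent` now implements `RLBase.get_role`, and the `update!(agent.policy, agent.trajectory)` call at PreActStage moves from before action selection to after the current state and action are pushed into the trajectory; the same call is also added to the remaining methods that push a (state, action) pair. So when a policy's `update!` runs, the trajectory already ends with the freshest (state, action). A hypothetical policy sketch, not part of this commit, showing what such an update sees:

# Hypothetical policy illustrating the new ordering: by the time
# `update!(policy, trajectory)` is called at PreActStage, the trajectory already
# contains the state just observed and the action just chosen for it.
using ReinforcementLearningBase

mutable struct LoggingPolicy <: AbstractPolicy
    n_updates::Int
end
LoggingPolicy() = LoggingPolicy(0)

(p::LoggingPolicy)(obs) = 1  # placeholder action selection

function RLBase.update!(p::LoggingPolicy, trajectory)
    # a real learner (e.g. a TD method) would read the newest transition here
    p.n_updates += 1
end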

src/components/agents/agents.jl

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 include("agent.jl")
+include("dyna_agent.jl")

src/components/agents/dyna_agent.jl

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+export DynaAgent
+
+Base.@kwdef struct DynaAgent{P<:AbstractPolicy, B<:AbstractTrajectory, M<:AbstractEnvironmentModel, R} <: AbstractAgent
+    policy::P
+    model::M
+    trajectory::B
+    role::R = DEFAULT_PLAYER
+    plan_step::Int = 10
+end
+
+RLBase.get_role(agent::DynaAgent) = agent.role
+
+function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
+    ::PreEpisodeStage,
+    obs,
+)
+    empty!(agent.trajectory)
+    nothing
+end
+
+function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
+    ::PreActStage,
+    obs,
+)
+    action = agent.policy(obs)
+    push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.model, agent.trajectory, agent.policy)  # model learning
+    update!(agent.policy, agent.trajectory)  # direct learning
+    update!(agent.policy, agent.model, agent.trajectory, agent.plan_step)  # policy learning
+    action
+end
+
+function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
+    ::PostActStage,
+    obs,
+)
+    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    nothing
+end
+
+function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
+    ::PostEpisodeStage,
+    obs,
+)
+    action = agent.policy(obs)
+    push!(agent.trajectory; state = get_state(obs), action = action)
+    update!(agent.model, agent.trajectory, agent.policy)  # model learning
+    update!(agent.policy, agent.trajectory)  # direct learning
+    update!(agent.policy, agent.model, agent.trajectory, agent.plan_step)  # policy learning
+    action
+end
+
+"By default, only use trajectory to update model"
+RLBase.update!(model::AbstractEnvironmentModel, t::AbstractTrajectory, π::AbstractPolicy) =
+    update!(model, t)
+
+function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory)
+    transitions = extract_experience(buffer, model)
+    isnothing(transitions) || update!(model, transitions)
+end
Lines changed: 37 additions & 10 deletions
@@ -1,21 +1,48 @@
 export TabularApproximator
 
 """
-    TabularApproximator(table::Vector{Float64}) -> TabularApproximator
-    TabularApproximator(;n_state::Int, init::Float64=0.0) -> TabularApproximator
+    TabularApproximator(table<:AbstractArray)
 
-Use a `table` of type `Vector{Float64}` of length `ns` to record the state values.
+For `table` of 1-d, it will create a [`VApproximator`](@ref). For `table` of 2-d, it will create a [`QApproximator`].
+
+!!! warning
+    For `table` of 2-d, the first dimension is action and the second dimension is state.
 """
-struct TabularApproximator <: AbstractApproximator
-    table::Vector{Float64}
+struct TabularApproximator{N, T<:AbstractArray} <: AbstractApproximator
+    table::T
+    function TabularApproximator(table::T) where {T<:AbstractArray}
+        n = ndims(table)
+        n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
+        new{n,T}(table)
+    end
+end
+
+function TabularApproximator(; n_state, n_action = nothing, init = 0.0)
+    table = isnothing(n_action) ? fill(init, n_state) : fill(init, n_action, n_state)
+    TabularApproximator(table)
 end
 
-TabularApproximator(; n_state::Int, init::Float64 = 0.0) =
-    TabularApproximator(fill(init, n_state))
+(app::TabularApproximator{1})(s::Int) = @views app.table[s]
 
-(v::TabularApproximator)(s::Int) = v.table[s]
+(app::TabularApproximator{2})(s::Int) = @views app.table[:, s]
+(app::TabularApproximator{2})(s::Int, a::Int) = app(s)[a]
 
-function RLBase.update!(v::TabularApproximator, correction::Pair{Int,Float64})
+function RLBase.update!(app::TabularApproximator{1}, correction::Pair)
     s, e = correction
-    v.table[s] += e
+    app.table[s] += e
+end
+
+function RLBase.update!(app::TabularApproximator{2}, correction::Pair)
+    (s, a), e = correction
+    app.table[a, s] += e
+end
+
+function RLBase.update!(Q::TabularApproximator{2}, correction::Pair{Int,Vector{Float64}})
+    s, errors = correction
+    for (a, e) in enumerate(errors)
+        Q.table[a, s] += e
+    end
 end
+
+RLBase.ApproximatorStyle(::TabularApproximator{1}) = VApproximator()
+RLBase.ApproximatorStyle(::TabularApproximator{2}) = QApproximator()
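
A short usage sketch of the reworked approximator (the numbers are arbitrary; `update!` is assumed to be the interface function exported by ReinforcementLearningBase):

using ReinforcementLearningBase, ReinforcementLearningCore

V = TabularApproximator(; n_state = 5)                 # 1-d table: state values
Q = TabularApproximator(; n_state = 5, n_action = 3)   # 2-d table: actions × states

V(2)       # value of state 2
Q(2)       # view of the three action values of state 2
Q(2, 1)    # value of action 1 in state 2

update!(V, 2 => 0.5)                  # V.table[2] += 0.5
update!(Q, (2, 1) => 0.5)             # Q.table[1, 2] += 0.5
update!(Q, 2 => [0.1, 0.0, -0.1])     # one error per action of state 2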

src/components/components.jl

Lines changed: 3 additions & 4 deletions
@@ -1,8 +1,7 @@
-include("spaces/spaces.jl")
-include("environments/environments.jl")
-include("policies/random_policy.jl")
+include("learners/learners.jl")
+include("policies/policies.jl")
 include("approximators/approximators.jl")
 include("explorers/explorers.jl")
 include("trajectories/trajectories.jl")
 include("preprocessors.jl")
-include("agents/agent.jl")
+include("agents/agents.jl")

src/components/environments/cartpole.jl

Lines changed: 0 additions & 113 deletions
This file was deleted.

src/components/environments/environments.jl

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/components/environments/wrapped_env.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.
