@@ -1,7 +1,9 @@
using Random, POMDPs, POMDPModels, SparseArrays, LinearAlgebra, StatsBase

export MDPEnv, POMDPEnv, SimpleMDPEnv, absorbing_deterministic_tree_MDP, stochastic_MDP, stochastic_tree_MDP,
-    deterministic_tree_MDP_with_rand_reward, deterministic_tree_MDP, deterministic_MDP
+    deterministic_tree_MDP_with_rand_reward, deterministic_tree_MDP, deterministic_MDP,
+    DeterministicStateActionReward, DeterministicNextStateReward,
+    NormalStateActionReward, NormalNextStateReward

# ####
# #### POMDPEnv
@@ -24,11 +26,11 @@ POMDPEnv(model; rng=Random.GLOBAL_RNG) = POMDPEnv(
    DiscreteSpace(n_states(model)),
    rng)

-function interact!(env::POMDPEnv, action)
+function interact!(env::POMDPEnv, action)
    s, o, r = generate_sor(env.model, env.state, env.actions[action], env.rng)
    env.state = s
-    (observation = observationindex(env.model, o),
-     reward = r,
+    (observation = observationindex(env.model, o),
+     reward = r,
     isdone = isterminal(env.model, s))
end

@@ -72,13 +74,13 @@ function interact!(env::MDPEnv, action)
    s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
    r = reward(env.model, env.state, env.actions[action])
    env.state = s
-    (observation = stateindex(env.model, s),
-     reward = r,
+    (observation = stateindex(env.model, s),
+     reward = r,
     isdone = isterminal(env.model, s))
end

function observe(env::MDPEnv)
-    (observation = stateindex(env.model, env.state),
+    (observation = stateindex(env.model, env.state),
     isdone = isterminal(env.model, env.state))
end

@@ -91,35 +93,97 @@
    na::Int64
    state::Int64
    trans_probs::Array{AbstractArray, 2}
-    reward::Array{Float64, 2}
+    reward::R
    initialstates::Array{Int64, 1}
    isterminal::Array{Int64, 1}
+    rng::S
A Markov Decision Process with `ns` states, `na` actions, current `state`,
`na`x`ns` - array of transition probabilities `trans_probs` which consists for
every (action, state) pair of a (potentially sparse) array that sums to 1 (see
[`get_prob_vec_random`](@ref), [`get_prob_vec_uniform`](@ref),
[`get_prob_vec_deterministic`](@ref) for helpers to construct the transition
-probabilities) `na`x`ns` - array of `reward`, array of initial states
+probabilities) `reward` of type `R` (see [`DeterministicStateActionReward`](@ref),
+[`DeterministicNextStateReward`](@ref), [`NormalNextStateReward`](@ref),
+[`NormalStateActionReward`](@ref)), array of initial states
`initialstates`, and `ns` - array of 0/1 indicating if a state is terminal.
"""
-mutable struct SimpleMDPEnv{T}
+mutable struct SimpleMDPEnv{T,R,S <: AbstractRNG}
    observation_space::DiscreteSpace
    action_space::DiscreteSpace
    state::Int64
    trans_probs::Array{T, 2}
-    reward::Array{Float64, 2}
+    reward::R
    initialstates::Array{Int64, 1}
    isterminal::Array{Int64, 1}
+    rng::S
end

function SimpleMDPEnv(ospace, aspace, state, trans_probs::Array{T, 2},
-                      reward, initialstates, isterminal) where T
-    SimpleMDPEnv{T}(ospace, aspace, state, trans_probs, reward, initialstates, isterminal)
+                      reward::R, initialstates, isterminal,
+                      rng::S = Random.GLOBAL_RNG) where {T,R,S}
+    if R <: AbstractMatrix # to ensure compatibility with previous versions
+        reward = DeterministicStateActionReward(reward)
+    end
+    SimpleMDPEnv{T,typeof(reward),S}(ospace, aspace, state, trans_probs,
+                                     reward, initialstates, isterminal, rng)
end

observation_space(env::SimpleMDPEnv) = env.observation_space
action_space(env::SimpleMDPEnv) = env.action_space

+# reward types
+"""
+    struct DeterministicNextStateReward
+        value::Vector{Float64}
+"""
+struct DeterministicNextStateReward
+    value::Vector{Float64}
+end
+reward(::AbstractRNG, r::DeterministicNextStateReward, s, a, s′) = r.value[s′]
+expected_rewards(r::DeterministicNextStateReward, trans_probs) = expected_rewards(r.value, trans_probs)
+function expected_rewards(r::AbstractVector, trans_probs)
+    result = zeros(size(trans_probs))
+    for i in eachindex(trans_probs)
+        result[i] = dot(trans_probs[i], r)
+    end
+    result
+end
+"""
+    struct DeterministicStateActionReward
+        value::Array{Float64, 2}
+
+`value` should be a `na × ns`-matrix.
+"""
+struct DeterministicStateActionReward
+    value::Array{Float64, 2}
+end
+reward(::AbstractRNG, r::DeterministicStateActionReward, s, a, s′) = r.value[a, s]
+expected_rewards(r::DeterministicStateActionReward, ::Any) = r.value
+"""
+    struct NormalNextStateReward
+        mean::Vector{Float64}
+        std::Vector{Float64}
+"""
+struct NormalNextStateReward
+    mean::Vector{Float64}
+    std::Vector{Float64}
+end
+reward(rng, r::NormalNextStateReward, s, a, s′) = r.mean[s′] + randn(rng) * r.std[s′]
+expected_rewards(r::NormalNextStateReward, trans_probs) = expected_rewards(r.mean, trans_probs)
+"""
+    struct NormalStateActionReward
+        mean::Array{Float64, 2}
+        std::Array{Float64, 2}
+
+`mean` and `std` should be `na × ns`-matrices.
+"""
+struct NormalStateActionReward
+    mean::Array{Float64, 2}
+    std::Array{Float64, 2}
+end
+reward(rng, r::NormalStateActionReward, s, a, s′) = r.mean[a, s] + randn(rng) * r.std[a, s]
+expected_rewards(r::NormalStateActionReward, ::Any) = r.mean
+
# run SimpleMDPEnv
"""
    run!(mdp::SimpleMDPEnv, action::Int64)
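As a quick illustration of the reward interface added in the hunk above (a hypothetical sketch, not part of the diff): `reward(rng, r, s, a, s′)` draws a single realization, while `expected_rewards(r, trans_probs)` reduces a reward model to the `na × ns` matrix of expected values. The one-action, two-state transition array below is made up for the example.

using Random, LinearAlgebra

r = NormalNextStateReward([0.0, 1.0], [0.1, 0.1])  # per-next-state mean and std
reward(MersenneTwister(0), r, 1, 1, 2)             # one noisy sample around mean[2] = 1.0

# trans_probs[a, s] is a distribution over next states (1 action, 2 states)
trans_probs = reshape([[0.9, 0.1], [0.2, 0.8]], 1, 2)
expected_rewards(r, trans_probs)                   # 1×2 matrix: [0.1 0.8]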
@@ -129,7 +193,7 @@ function run!(mdp::SimpleMDPEnv, action::Int64)
    if mdp.isterminal[mdp.state] == 1
        reset!(mdp)
    else
-        mdp.state = wsample(mdp.trans_probs[action, mdp.state])
+        mdp.state = wsample(mdp.rng, mdp.trans_probs[action, mdp.state])
        (observation = mdp.state,)
    end
end
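One consequence of threading `mdp.rng` through `wsample` here (and through `rand` in `reset!` below) is that two environments built with the same seed now step identically. A small check along these lines (made-up two-state environment, names illustrative only) could look like:

using Random

make_env(seed) = SimpleMDPEnv(DiscreteSpace(2), DiscreteSpace(1), 1,
                              reshape([[0.5, 0.5], [0.5, 0.5]], 1, 2),
                              zeros(1, 2),          # matrix reward, wrapped by the compatibility branch
                              [1, 2], [0, 0], MersenneTwister(seed))
a, b = make_env(7), make_env(7)
@assert interact!(a, 1) == interact!(b, 1)          # same seed, same transition and reward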
@@ -141,8 +205,9 @@ run!(mdp::SimpleMDPEnv, policy::Array{Int64, 1}) = run!(mdp, policy[mdp.state])


function interact!(env::SimpleMDPEnv, action)
-    r = env.reward[action, env.state]
+    oldstate = env.state
    run!(env, action)
+    r = reward(env.rng, env.reward, oldstate, action, env.state)
    (observation = env.state, reward = r, isdone = env.isterminal[env.state] == 1)
end

@@ -151,7 +216,7 @@ function observe(env::SimpleMDPEnv)
end

function reset!(env::SimpleMDPEnv)
-    env.state = rand(env.initialstates)
+    env.state = rand(env.rng, env.initialstates)
    nothing
end

@@ -274,7 +339,7 @@ function deterministic_MDP(; ns = 10^4, na = 10)
end

"""
-    deterministic_tree_MDP(; na = 4, depth = 5)
+    deterministic_tree_MDP(; na = 4, depth = 5)
Returns a tree_MDP with random rewards at the leaf nodes.
"""
function deterministic_tree_MDP(; na = 4, depth = 5)
@@ -317,4 +382,4 @@ function absorbing_deterministic_tree_MDP(;ns = 10^3, na = 10)
    reset!(mdp)
    set_terminal_states!(mdp, ns - div(ns, 100) + 1:ns)
    mdp
-end
+end
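Finally, a minimal end-to-end sketch of driving the reworked `SimpleMDPEnv` after this change (hypothetical values throughout; the one-argument `DiscreteSpace` constructor is assumed to behave as in the `POMDPEnv` code above, and the transition vectors are built by hand rather than with the `get_prob_vec_*` helpers):

using Random, SparseArrays

ns, na = 3, 2
# trans_probs[a, s]: a (possibly sparse) probability vector over next states
trans_probs = [sparsevec([min(s + a - 1, ns)], [1.0], ns) for a in 1:na, s in 1:ns]
env = SimpleMDPEnv(DiscreteSpace(ns), DiscreteSpace(na), 1, trans_probs,
                   NormalNextStateReward(ones(ns), fill(0.1, ns)),  # noisy rewards, mean 1.0
                   [1],                 # initial states
                   [0, 0, 1],           # state 3 is terminal
                   MersenneTwister(3))  # explicit rng, used by run!, reset! and the reward draw
interact!(env, 2)   # -> (observation = 2, reward ≈ 1.0, isdone = false)
reset!(env)         # re-draws the state from initialstates using env.rng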