function interact!(env::MDPEnv, action)
    s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
-   r = reward(env.model, env.state, env.actions[action])
+   r = POMDPs.reward(env.model, env.state, env.actions[action])
    env.state = s
    (observation = stateindex(env.model, s),
     reward = r,
@@ -141,6 +141,7 @@ struct DeterministicNextStateReward
end
reward(::AbstractRNG, r::DeterministicNextStateReward, s, a, s′) = r.value[s′]
expected_rewards(r::DeterministicNextStateReward, trans_probs) = expected_rewards(r.value, trans_probs)
+
function expected_rewards(r::AbstractVector, trans_probs)
    result = zeros(size(trans_probs))
    for i in eachindex(trans_probs)
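The hunk is cut off by the diff context here, but the first lines already show the intent: `expected_rewards` turns each transition-probability vector into the expectation of the deterministic next-state reward. A minimal sketch of that expectation for a single state-action pair (the variable names below are illustrative, not from the patch):

# Illustrative only (names not from the diff): for a DeterministicNextStateReward,
# the expected reward of one (s, a) pair is E[r(s′)] = Σ_{s′} p(s′|s, a) * value[s′].
value = [0.0, 1.0, -1.0]     # reward attached to each possible next state
p     = [0.2, 0.5, 0.3]      # transition probabilities for one (s, a) pair
expected = sum(p .* value)   # 0.5 * 1.0 + 0.3 * (-1.0) = 0.2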
@@ -321,7 +322,11 @@ the same value.
function set_terminal_states!(mdp, range)
    mdp.isterminal[range] .= 1
    for s in findall(x -> x == 1, mdp.isterminal)
-       mdp.reward[:, s] .= mean(mdp.reward[:, s])
+       if mdp.reward isa DeterministicStateActionReward
+           mdp.reward.value[:, s] .= mean(mdp.reward.value[:, s])
+       else
+           mdp.reward[:, s] .= mean(mdp.reward[:, s])
+       end
        for a in 1:length(mdp.action_space)
            empty_trans_prob!(mdp.trans_probs[a, s])
        end
@@ -334,7 +339,11 @@ Returns a random deterministic SimpleMDPEnv.
"""
function deterministic_MDP(; ns = 10^4, na = 10)
    mdp = SimpleMDPEnv(ns, na, init = "deterministic")
-   mdp.reward = mdp.reward .* (mdp.reward .< -1.5)
+   if mdp.reward isa DeterministicStateActionReward
+       mdp.reward.value .*= mdp.reward.value .< -1.5
+   else
+       mdp.reward = mdp.reward .* (mdp.reward .< -1.5)
+   end
    mdp
end
@@ -353,7 +362,11 @@ Returns a tree_MDP with random rewards.
function deterministic_tree_MDP_with_rand_reward(; args...)
    mdp = deterministic_tree_MDP(; args...)
    nonterminals = findall(x -> x == 0, mdp.isterminal)
-   mdp.reward[:, nonterminals] = -rand(length(mdp.action_space), length(nonterminals))
+   if mdp.reward isa DeterministicStateActionReward
+       mdp.reward.value[:, nonterminals] = -rand(length(mdp.action_space), length(nonterminals))
+   else
+       mdp.reward[:, nonterminals] = -rand(length(mdp.action_space), length(nonterminals))
+   end
    mdp
end
@@ -377,7 +390,11 @@ Returns a random deterministic absorbing SimpleMDPEnv
"""
function absorbing_deterministic_tree_MDP(;ns = 10^3, na = 10)
    mdp = SimpleMDPEnv(ns, na, init = "deterministic")
-   mdp.reward .= mdp.reward .* (mdp.reward .< -.5)
+   if mdp.reward isa DeterministicStateActionReward
+       mdp.reward.value .*= mdp.reward.value .< -.5
+   else
+       mdp.reward .= mdp.reward .* (mdp.reward .< -.5)
+   end
    mdp.initialstates = 1:div(ns, 100)
    reset!(mdp)
    set_terminal_states!(mdp, ns - div(ns, 100) + 1:ns)
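The same `isa DeterministicStateActionReward` branch now appears in all four functions above. A possible follow-up, sketched here with a hypothetical helper (`reward_matrix` is not defined anywhere in this patch), would be to resolve the reward representation once by dispatch and mutate the underlying matrix in place:

# Hypothetical refactor, not part of this patch: a small accessor could collapse
# the repeated `isa DeterministicStateActionReward` branches into dispatch.
reward_matrix(r::DeterministicStateActionReward) = r.value
reward_matrix(r::AbstractArray) = r

# Each constructor could then mutate the underlying matrix in place, e.g.
#     R = reward_matrix(mdp.reward)
#     R .*= R .< -1.5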