@@ -272,25 +272,22 @@ end
272
272
273
273
using IntervalSets
274
274
275
- """
276
- watch https://github.com/JuliaMath/IntervalSets.jl/issues/66
277
- """
278
- function Base. in (x:: AbstractArray , s:: Array{<:Interval} )
279
- size (x) == size (s) && all (x .∈ s)
280
- end
281
-
282
275
Random. rand (s:: Union{Interval, Array{<:Interval}} ) = rand (Random. GLOBAL_RNG, s)
283
276
284
277
function Random. rand (rng:: AbstractRNG , s:: Interval )
285
278
rand (rng) * (s. right - s. left) + s. left
286
279
end
287
280
288
- function Random . rand (rng :: AbstractRNG , s :: Array{<:Interval} )
289
- map (x -> rand (rng, x), s)
290
- end
281
+ # ####
282
+ # WorldSpace
283
+ # ####
291
284
292
285
export WorldSpace
293
286
287
+ """
288
+ In some cases, we may not be interested in the action/state space.
289
+ One can return `WorldSpace()` to keep the interface consistent.
290
+ """
294
291
struct WorldSpace{T} end
295
292
296
293
WorldSpace () = WorldSpace {Any} ()
355
352
Random. rand (rng:: AbstractRNG , s:: AbstractVector{<:ActionProbPair} ) = s[weighted_sample (rng, (x. prob for x in s))]
356
353
357
354
(env:: AbstractEnv )(a:: ActionProbPair ) = env (a. action)
355
+
356
+ # ####
357
+ # Space
358
+ # ####
359
+
360
+ export Space
361
+
362
+ """
363
+ A wrapper to treat each element as a sub-space which supports `Random.rand` and `Base.in`.
364
+ """
365
+ struct Space{T}
366
+ s:: T
367
+ end
368
+
369
+ Random. rand (s:: Space ) = rand (Random. GLOBAL_RNG, s)
370
+
371
+ Random. rand (rng:: AbstractRNG , s:: Space ) = map (s. s) do x
372
+ rand (rng, x)
373
+ end
374
+
375
+ Random. rand (rng:: AbstractRNG , s:: Space{<:Dict} ) = Dict (k=> rand (rng,v) for (k,v) in s. s)
376
+
377
+ function Base. in (X, S:: Space )
378
+ if length (X) == length (S. s)
379
+ for (x,s) in zip (X, S. s)
380
+ if x ∉ s
381
+ return false
382
+ end
383
+ end
384
+ return true
385
+ else
386
+ return false
387
+ end
388
+ end
389
+
390
+ function Base. in (X:: Dict , S:: Space{<:Dict} )
391
+ if keys (X) == keys (S. s)
392
+ for k in keys (X)
393
+ if X[k] ∉ S. s[k]
394
+ return false
395
+ end
396
+ end
397
+ return true
398
+ else
399
+ return false
400
+ end
401
+ end
0 commit comments