Skip to content
This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Commit 2002b56

Browse files
committed
add a general Space
1 parent 73a0c70 commit 2002b56

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

src/base.jl

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,25 +272,22 @@ end
272272

273273
using IntervalSets
274274

275-
"""
276-
watch https://github.com/JuliaMath/IntervalSets.jl/issues/66
277-
"""
278-
function Base.in(x::AbstractArray, s::Array{<:Interval})
279-
size(x) == size(s) && all(x .∈ s)
280-
end
281-
282275
Random.rand(s::Union{Interval, Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)
283276

284277
function Random.rand(rng::AbstractRNG, s::Interval)
285278
rand(rng) * (s.right - s.left) + s.left
286279
end
287280

288-
function Random.rand(rng::AbstractRNG, s::Array{<:Interval})
289-
map(x -> rand(rng, x), s)
290-
end
281+
#####
282+
# WorldSpace
283+
#####
291284

292285
export WorldSpace
293286

287+
"""
288+
In some cases, we may not be interested in the action/state space.
289+
One can return `WorldSpace()` to keep the interface consistent.
290+
"""
294291
struct WorldSpace{T} end
295292

296293
WorldSpace() = WorldSpace{Any}()
@@ -355,3 +352,50 @@ end
355352
Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) = s[weighted_sample(rng, (x.prob for x in s))]
356353

357354
(env::AbstractEnv)(a::ActionProbPair) = env(a.action)
355+
356+
#####
357+
# Space
358+
#####
359+
360+
export Space
361+
362+
"""
363+
A wrapper to treat each element as a sub-space which supports `Random.rand` and `Base.in`.
364+
"""
365+
struct Space{T}
366+
s::T
367+
end
368+
369+
Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)
370+
371+
Random.rand(rng::AbstractRNG, s::Space) = map(s.s) do x
372+
rand(rng, x)
373+
end
374+
375+
Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k=>rand(rng,v) for (k,v) in s.s)
376+
377+
function Base.in(X, S::Space)
378+
if length(X) == length(S.s)
379+
for (x,s) in zip(X, S.s)
380+
if x s
381+
return false
382+
end
383+
end
384+
return true
385+
else
386+
return false
387+
end
388+
end
389+
390+
function Base.in(X::Dict, S::Space{<:Dict})
391+
if keys(X) == keys(S.s)
392+
for k in keys(X)
393+
if X[k] S.s[k]
394+
return false
395+
end
396+
end
397+
return true
398+
else
399+
return false
400+
end
401+
end

0 commit comments

Comments
 (0)