Skip to content
This repository was archived by the owner on Aug 11, 2023. It is now read-only.

Commit 6aa6ee0

Browse files
authored
move Space from RLEnvs to RLBase (#119)
1 parent e4f9ff9 commit 6aa6ee0

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

src/base.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,73 @@
1+
#####
2+
# Spaces
3+
#####
4+
5+
export WorldSpace
6+
7+
"""
8+
In some cases, we may not be interested in the action/state space.
9+
One can return `WorldSpace()` to keep the interface consistent.
10+
"""
11+
struct WorldSpace{T} end
12+
13+
WorldSpace() = WorldSpace{Any}()
14+
15+
Base.in(x, ::WorldSpace{T}) where {T} = x isa T
16+
17+
export Space
18+
19+
"""
20+
A wrapper to treat each element as a sub-space which supports:
21+
22+
- `Base.in`
23+
- `Random.rand`
24+
"""
25+
struct Space{T}
26+
s::T
27+
end
28+
29+
Base.similar(s::Space, args...) = Space(similar(s.s, args...))
30+
Base.getindex(s::Space, args...) = getindex(s.s, args...)
31+
Base.setindex!(s::Space, args...) = setindex!(s.s, args...)
32+
Base.size(s::Space) = size(s.s)
33+
Base.length(s::Space) = length(s.s)
34+
Base.iterate(s::Space, args...) = iterate(s.s, args...)
35+
36+
Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)
37+
38+
Random.rand(rng::AbstractRNG, s::Space) =
39+
map(s.s) do x
40+
rand(rng, x)
41+
end
42+
43+
Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k => rand(rng, v) for (k, v) in s.s)
44+
45+
function Base.in(X, S::Space)
46+
if length(X) == length(S.s)
47+
for (x, s) in zip(X, S.s)
48+
if x s
49+
return false
50+
end
51+
end
52+
return true
53+
else
54+
return false
55+
end
56+
end
57+
58+
function Base.in(X::Dict, S::Space{<:Dict})
59+
if keys(X) == keys(S.s)
60+
for k in keys(X)
61+
if X[k] S.s[k]
62+
return false
63+
end
64+
end
65+
return true
66+
else
67+
return false
68+
end
69+
end
70+
171
#####
272
# printing
373
#####

0 commit comments

Comments
 (0)