@@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces, Episode, Episodes
3
3
import MacroTools: @forward
4
4
5
5
import CircularArrayBuffers
6
+ import Adapt
6
7
7
8
# ####
8
9
@@ -13,11 +14,23 @@ Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x
13
14
Base. summary (io:: IO , t:: AbstractTrace ) = print (io, " $(length (t)) -element $(nameof (typeof (t))) " )
14
15
15
16
# ####
17
+
18
+ """
19
+ Trace(A::AbstractArray)
20
+
21
+ Similar to
22
+ [`Slices`](https://github.com/JuliaLang/julia/blob/master/base/slicearray.jl)
23
+ which will be introduced in `[email protected] `. The main difference is that, the
24
+ `axes` info in the `Slices` is static, while it may be dynamic with `Trace`.
25
+
26
+ We only support slices along the last dimension since it's the most common usage
27
+ in RL.
28
+ """
16
29
struct Trace{T,E} <: AbstractTrace{E}
17
30
parent:: T
18
31
end
19
32
20
- Base. summary (io:: IO , t:: Trace{T} ) where {T} = print (io, " $(length (t)) -element $(nameof (typeof (t))) {$T }" )
33
+ Base. summary (io:: IO , t:: Trace{T} ) where {T} = print (io, " $(length (t)) -element$( length (t) > 0 ? ' s ' : " " ) $(nameof (typeof (t))) {$T }" )
21
34
22
35
function Trace (x:: T ) where {T<: AbstractArray }
23
36
E = eltype (x)
@@ -27,6 +40,8 @@ function Trace(x::T) where {T<:AbstractArray}
27
40
Trace {T,SubArray{E,N,P,I,true}} (x)
28
41
end
29
42
43
+ Adapt. adapt_structure (to, t:: Trace ) = Trace (Adapt. adapt_structure (to, t. parent))
44
+
30
45
Base. convert (:: Type{AbstractTrace} , x:: AbstractArray ) = Trace (x)
31
46
32
47
Base. size (x:: Trace ) = (size (x. parent, ndims (x. parent)),)
@@ -59,6 +74,21 @@ Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names
59
74
60
75
# ####
61
76
77
+ """
78
+ Dedicated for `MultiplexTraces` to avoid scalar indexing when `view(view(t::MultiplexTrace, 1:end-1), I)`.
79
+ """
80
+ struct RelativeTrace{left,right,T,E} <: AbstractTrace{E}
81
+ trace:: Trace{T,E}
82
+ end
83
+ RelativeTrace {left,right} (t:: Trace{T,E} ) where {left,right,T,E} = RelativeTrace {left,right,T,E} (t)
84
+
85
+ Base. size (x:: RelativeTrace{0,-1} ) = (max (0 , length (x. trace) - 1 ),)
86
+ Base. size (x:: RelativeTrace{1,0} ) = (max (0 , length (x. trace) - 1 ),)
87
+ Base. getindex (s:: RelativeTrace{0,-1} , I) = getindex (s. trace, I)
88
+ Base. getindex (s:: RelativeTrace{1,0} , I) = getindex (s. trace, I .+ 1 )
89
+ Base. setindex! (s:: RelativeTrace{0,-1} , v, I) = setindex! (s. trace, v, I)
90
+ Base. setindex! (s:: RelativeTrace{1,0} , v, I) = setindex! (s. trace, v, I .+ 1 )
91
+
62
92
"""
63
93
MultiplexTraces{names}(trace)
64
94
@@ -89,12 +119,14 @@ function MultiplexTraces{names}(t) where {names}
89
119
MultiplexTraces {names,typeof(trace),eltype(trace)} (trace)
90
120
end
91
121
122
+ Adapt. adapt_structure (to, t:: MultiplexTraces{names} ) where {names} = MultiplexTraces {names} (Adapt. adapt_structure (to, t. trace))
123
+
92
124
function Base. getindex (t:: MultiplexTraces{names} , k:: Symbol ) where {names}
93
125
a, b = names
94
126
if k == a
95
- convert (AbstractTrace, t. trace[ 1 : end - 1 ] )
127
+ RelativeTrace {0,-1} ( convert (AbstractTrace, t. trace) )
96
128
elseif k == b
97
- convert (AbstractTrace, t. trace[ 2 : end ] )
129
+ RelativeTrace {1,0} ( convert (AbstractTrace, t. trace) )
98
130
else
99
131
throw (ArgumentError (" unknown trace name: $k " ))
100
132
end
133
165
134
166
Episode (t:: AbstractTraces{names,T} ) where {names,T} = Episode {typeof(t),names,T} (t, Ref (false ))
135
167
168
+ Adapt. adapt_structure (to, t:: Episode{T,names,E} ) where {T,names,E} = Episode {T,names,E} (Adapt. adapt_structure (to, t. traces), t. is_terminated)
169
+
136
170
@forward Episode. traces Base. getindex, Base. setindex!, Base. size
137
171
138
172
Base. getindex (e:: Episode ) = getindex (e. is_terminated)
@@ -175,6 +209,11 @@ struct Episodes{names,E,T} <: AbstractTraces{names,E}
175
209
inds:: Vector{Tuple{Int,Int}}
176
210
end
177
211
212
+ Adapt. adapt_structure (to, t:: Episodes ) =
213
+ Episodes () do
214
+ Adapt. adapt_structure (to, t. init ())
215
+ end
216
+
178
217
function Episodes (init)
179
218
x = init ()
180
219
T = typeof (x)
@@ -249,6 +288,11 @@ struct Traces{names,T,N,E} <: AbstractTraces{names,E}
249
288
inds:: NamedTuple{names,NTuple{N,Int}}
250
289
end
251
290
291
+ function Adapt. adapt_structure (to, t:: Traces{names,T,N,E} ) where {names,T,N,E}
292
+ data = Adapt. adapt_structure (to, t. traces)
293
+ # FIXME : `E` is not adapted here
294
+ Traces {names,typeof(data),length(names),E} (data, t. inds)
295
+ end
252
296
253
297
function Traces (; kw... )
254
298
data = map (x -> convert (AbstractTrace, x), values (kw))
0 commit comments