Skip to content

Commit 9808604

Browse files
committed
add patch code for StackViews
1 parent 016bf93 commit 9808604

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

src/patch.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
import MLUtils
22

3-
MLUtils.batch(x::AbstractArray{<:Number}) = x
3+
MLUtils.batch(x::AbstractArray{<:Number}) = x
4+
5+
#####
6+
7+
import StackViews: StackView
8+
9+
lazy_stack(x) = StackView(x)
10+
lazy_stack(x::AbstractVector{<:Number}) = x

src/traces.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
export Trace, Traces, MultiplexTraces, Episode, Episodes
22

33
import MacroTools: @forward
4-
import StackViews: StackView
54

65
#####
76

@@ -196,7 +195,7 @@ end
196195

197196
function Base.getindex(e::Episodes{names}, I) where {names}
198197
NamedTuple{names}(
199-
StackView(
198+
lazy_stack(
200199
map(I) do i
201200
x, y = e.inds[i]
202201
e.episodes[x][n][y]

test/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@test length(e) == 5
2323
@test size(e[2:4].state) == (2, 3, 3)
24-
@test_broken size(e[2:4].action) == (3,)
24+
@test size(e[2:4].action) == (3,)
2525
end
2626

2727
@testset "MetaSampler" begin

test/traces.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,12 @@ end
167167

168168
@test t[end] == (state=2.0, action=2)
169169

170-
# https://github.com/JuliaArrays/StackViews.jl/issues/3
171-
@test_broken t[1:2] == (state=[1.0, 2.0], action=[1, 2])
170+
@test t[1:2] == (state=[1.0, 2.0], action=[1, 2])
172171

173172
push!(t, (state=3.0, action=3))
174173
t[] = true # seal
175174

175+
# a vector of episode-level partitions is returned for now
176176
@test_broken size(t[:state]) == (3,)
177177

178178
push!(t, Episode(Traces(state=[4.0, 5.0, 6.0], action=[4, 5, 6])))

0 commit comments

Comments
 (0)