Commit ba1b967

Merge pull request #31 from JuliaReinforcementLearning/test_on_GPU
add tests on GPU
2 parents: 60ff6f6 + b96a9b2

File tree

5 files changed, +75 -18 lines


.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ jobs:
       matrix:
         version:
           - '1.6'
-          - '1.7'
+          - '1'
           - 'nightly'
         os:
           - ubuntu-latest

Project.toml

Lines changed: 4 additions & 2 deletions

@@ -3,6 +3,7 @@ uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
 version = "0.1.5"

 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
 ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -14,12 +15,13 @@ StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
 CircularArrayBuffers = "0.1"
 ElasticArrays = "1"
 MacroTools = "0.5"
+OnlineStats = "1"
 StackViews = "0.1"
 julia = "1.6"
-OnlineStats = "1"

 [extras]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Test"]
+test = ["Test", "CUDA"]

src/traces.jl

Lines changed: 47 additions & 3 deletions

@@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces, Episode, Episodes
 import MacroTools: @forward

 import CircularArrayBuffers
+import Adapt

 #####

@@ -13,11 +14,23 @@ Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x
 Base.summary(io::IO, t::AbstractTrace) = print(io, "$(length(t))-element $(nameof(typeof(t)))")

 #####
+
+"""
+    Trace(A::AbstractArray)
+
+Similar to
+[`Slices`](https://github.com/JuliaLang/julia/blob/master/base/slicearray.jl)
+which will be introduced in `[email protected]`. The main difference is that, the
+`axes` info in the `Slices` is static, while it may be dynamic with `Trace`.
+
+We only support slices along the last dimension since it's the most common usage
+in RL.
+"""
 struct Trace{T,E} <: AbstractTrace{E}
     parent::T
 end

-Base.summary(io::IO, t::Trace{T}) where {T} = print(io, "$(length(t))-element $(nameof(typeof(t))){$T}")
+Base.summary(io::IO, t::Trace{T}) where {T} = print(io, "$(length(t))-element$(length(t) > 0 ? 's' : "") $(nameof(typeof(t))){$T}")

 function Trace(x::T) where {T<:AbstractArray}
     E = eltype(x)
@@ -27,6 +40,8 @@ function Trace(x::T) where {T<:AbstractArray}
     Trace{T,SubArray{E,N,P,I,true}}(x)
 end

+Adapt.adapt_structure(to, t::Trace) = Trace(Adapt.adapt_structure(to, t.parent))
+
 Base.convert(::Type{AbstractTrace}, x::AbstractArray) = Trace(x)

 Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
@@ -59,6 +74,21 @@ Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names

 #####

+"""
+Dedicated for `MultiplexTraces` to avoid scalar indexing when `view(view(t::MultiplexTrace, 1:end-1), I)`.
+"""
+struct RelativeTrace{left,right,T,E} <: AbstractTrace{E}
+    trace::Trace{T,E}
+end
+RelativeTrace{left,right}(t::Trace{T,E}) where {left,right,T,E} = RelativeTrace{left,right,T,E}(t)
+
+Base.size(x::RelativeTrace{0,-1}) = (max(0, length(x.trace) - 1),)
+Base.size(x::RelativeTrace{1,0}) = (max(0, length(x.trace) - 1),)
+Base.getindex(s::RelativeTrace{0,-1}, I) = getindex(s.trace, I)
+Base.getindex(s::RelativeTrace{1,0}, I) = getindex(s.trace, I .+ 1)
+Base.setindex!(s::RelativeTrace{0,-1}, v, I) = setindex!(s.trace, v, I)
+Base.setindex!(s::RelativeTrace{1,0}, v, I) = setindex!(s.trace, v, I .+ 1)
+
 """
     MultiplexTraces{names}(trace)

@@ -89,12 +119,14 @@ function MultiplexTraces{names}(t) where {names}
     MultiplexTraces{names,typeof(trace),eltype(trace)}(trace)
 end

+Adapt.adapt_structure(to, t::MultiplexTraces{names}) where {names} = MultiplexTraces{names}(Adapt.adapt_structure(to, t.trace))
+
 function Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names}
     a, b = names
     if k == a
-        convert(AbstractTrace, t.trace[1:end-1])
+        RelativeTrace{0,-1}(convert(AbstractTrace, t.trace))
     elseif k == b
-        convert(AbstractTrace, t.trace[2:end])
+        RelativeTrace{1,0}(convert(AbstractTrace, t.trace))
     else
         throw(ArgumentError("unknown trace name: $k"))
     end
@@ -133,6 +165,8 @@ end

 Episode(t::AbstractTraces{names,T}) where {names,T} = Episode{typeof(t),names,T}(t, Ref(false))

+Adapt.adapt_structure(to, t::Episode{T,names,E}) where {T,names,E} = Episode{T,names,E}(Adapt.adapt_structure(to, t.traces), t.is_terminated)
+
 @forward Episode.traces Base.getindex, Base.setindex!, Base.size

 Base.getindex(e::Episode) = getindex(e.is_terminated)
@@ -175,6 +209,11 @@ struct Episodes{names,E,T} <: AbstractTraces{names,E}
     inds::Vector{Tuple{Int,Int}}
 end

+Adapt.adapt_structure(to, t::Episodes) =
+    Episodes() do
+        Adapt.adapt_structure(to, t.init())
+    end
+
 function Episodes(init)
     x = init()
     T = typeof(x)
@@ -249,6 +288,11 @@ struct Traces{names,T,N,E} <: AbstractTraces{names,E}
     inds::NamedTuple{names,NTuple{N,Int}}
 end

+function Adapt.adapt_structure(to, t::Traces{names,T,N,E}) where {names,T,N,E}
+    data = Adapt.adapt_structure(to, t.traces)
+    # FIXME: `E` is not adapted here
+    Traces{names,typeof(data),length(names),E}(data, t.inds)
+end

 function Traces(; kw...)
     data = map(x -> convert(AbstractTrace, x), values(kw))
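An aside, not part of the diff: a minimal sketch of the RelativeTrace offsets introduced above, using only the exported MultiplexTraces API. The printed values assume Trace keeps indexing its parent along the last dimension, as its docstring states.

    using ReinforcementLearningTrajectories

    m = MultiplexTraces{(:state, :next_state)}([10, 20, 30])   # one shared backing array
    s, ns = m[:state], m[:next_state]   # RelativeTrace{0,-1} and RelativeTrace{1,0}
    size(s) == size(ns) == (2,)         # both report one element fewer than the parent
    s[1], ns[1]                         # 10 and 20: the {1,0} wrapper shifts indices by one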

test/common.jl

Lines changed: 15 additions & 12 deletions

@@ -31,20 +31,21 @@ end
         action=Float32 => (2,),
         reward=Float32 => (),
         terminal=Bool => ()
-    )
+    ) |> gpu

     @test t isa CircularArraySARTTraces

-    push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)))
+    push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu)
     @test length(t) == 0

-    push!(t, (reward=1.0f0, terminal=false))
+    push!(t, (reward=1.0f0, terminal=false) |> gpu)
     @test length(t) == 0 # next_state and next_action is still missing

-    push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2))
+    push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu)
     @test length(t) == 1

-    @test t[1] == (
+    # this will trigger the scalar indexing of CuArray
+    CUDA.@allowscalar @test t[1] == (
         state=ones(Float32, 2, 3),
         next_state=ones(Float32, 2, 3) * 2,
         action=ones(Float32, 2),
@@ -54,28 +55,30 @@ end
     )

     push!(t, (reward=2.0f0, terminal=false))
-    push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3))
+    push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu)

     @test length(t) == 2

     push!(t, (reward=3.0f0, terminal=false))
-    push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4))
+    push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu)

     @test length(t) == 3

     push!(t, (reward=4.0f0, terminal=false))
-    push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5))
+    push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu)

     @test length(t) == 3
-    @test t[1] == (
+
+    # this will trigger the scalar indexing of CuArray
+    CUDA.@allowscalar @test t[1] == (
         state=ones(Float32, 2, 3) * 2,
         next_state=ones(Float32, 2, 3) * 3,
         action=ones(Float32, 2) * 2,
         next_action=ones(Float32, 2) * 3,
         reward=2.0f0,
         terminal=false,
     )
-    @test t[end] == (
+    CUDA.@allowscalar @test t[end] == (
         state=ones(Float32, 2, 3) * 4,
         next_state=ones(Float32, 2, 3) * 5,
         action=ones(Float32, 2) * 4,
@@ -87,8 +90,8 @@ end
     batch = t[1:3]
     @test size(batch.state) == (2, 3, 3)
     @test size(batch.action) == (2, 3)
-    @test batch.reward == [2.0, 3.0, 4.0]
-    @test batch.terminal == Bool[0, 0, 0]
+    @test batch.reward == [2.0, 3.0, 4.0] |> gpu
+    @test batch.terminal == Bool[0, 0, 0] |> gpu
 end

 @testset "ElasticArraySARTTraces" begin

test/runtests.jl

Lines changed: 8 additions & 0 deletions

@@ -1,6 +1,14 @@
 using ReinforcementLearningTrajectories
 using CircularArrayBuffers
 using Test
+using CUDA
+using Adapt
+
+struct TestAdaptor end
+
+gpu(x) = Adapt.adapt(TestAdaptor(), x)
+
+Adapt.adapt_storage(to::TestAdaptor, x) = CUDA.functional() ? CUDA.cu(x) : x

 @testset "ReinforcementLearningTrajectories.jl" begin
     include("traces.jl")
