Skip to content

Commit f0af6a2

Browse files
committed
rename some variables to avoid confusion
1 parent 1ff0d8f commit f0af6a2

File tree

4 files changed

+25
-25
lines changed

4 files changed

+25
-25
lines changed

src/common/CircularArraySARTTraces.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
export CircularArraySARTTraces
22

33
const CircularArraySARTTraces = Traces{
4-
SSAART,
4+
SS′AA′RT,
55
<:Tuple{
6-
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7-
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
6+
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7+
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
88
<:Trace{<:CircularArrayBuffer},
99
<:Trace{<:CircularArrayBuffer},
1010
}
@@ -22,8 +22,8 @@ function CircularArraySARTTraces(;
2222
reward_eltype, reward_size = reward
2323
terminal_eltype, terminal_size = terminal
2424

25-
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
26-
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
25+
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
26+
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
2727
Traces(
2828
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
2929
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),

src/common/CircularArraySLARTTraces.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
export CircularArraySLARTTraces
22

33
const CircularArraySLARTTraces = Traces{
4-
SSLLAART,
4+
SS′LL′AA′RT,
55
<:Tuple{
6-
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7-
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
8-
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
6+
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
7+
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
8+
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
99
<:Trace{<:CircularArrayBuffer},
1010
<:Trace{<:CircularArrayBuffer},
1111
}
@@ -25,9 +25,9 @@ function CircularArraySLARTTraces(;
2525
reward_eltype, reward_size = reward
2626
terminal_eltype, terminal_size = terminal
2727

28-
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
29-
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
30-
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
28+
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
29+
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
30+
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
3131
Traces(
3232
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
3333
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),

src/common/common.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
export SS, LL, AA, RT, SSART, SSAART, SSLLAART
1+
export SS, LL, AA, RT, SS′ART, SS′AA′RT, SS′L′ART, SS′LL′AA′RT
22

33
using CircularArrayBuffers
44

5-
const SS = (:state, :next_state)
6-
const LL = (:legal_actions_mask, :next_legal_actions_mask)
7-
const AA = (:action, :next_action)
5+
const SS = (:state, :next_state)
6+
const LL = (:legal_actions_mask, :next_legal_actions_mask)
7+
const AA = (:action, :next_action)
88
const RT = (:reward, :terminal)
9-
const SSART = (SS..., :action, RT...)
10-
const SSAART = (SS..., AA..., RT...)
11-
const SSLART = (SS..., :legal_actions_mask, :action, RT...)
12-
const SSLLAART = (SS..., LL..., AA..., RT...)
9+
const SS′ART = (SS..., :action, RT...)
10+
const SS′AA′RT = (SS..., AA..., RT...)
11+
const SS′L′ART = (SS..., :next_legal_actions_mask, :action, RT...)
12+
const SS′LL′AA′RT = (SS..., LL..., AA..., RT...)
1313

1414
include("sum_tree.jl")
1515
include("CircularArraySARTTraces.jl")

src/samplers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
3131

3232
function sample(s::BatchSampler, t::AbstractTraces, names)
3333
inds = rand(s.rng, 1:length(t), s.batch_size)
34-
NamedTuple{names}(t[x][inds] for x in names)
34+
NamedTuple{names}(map(x -> t[x][inds], names))
3535
end
3636

3737
#####
@@ -108,7 +108,7 @@ function sample(s::NStepBatchSampler{names}, ts) where {names}
108108
sample(s, ts, Val(names), inds)
109109
end
110110

111-
function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds)
111+
function sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
112112
if isnothing(nbs.stack_size)
113113
s = ts[:state][inds]
114114
s′ = ts[:next_state][inds.+(nbs.n-1)]
@@ -129,11 +129,11 @@ function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds)
129129
foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
130130
end
131131

132-
NamedTuple{SSART}((s, s′, a, r, t))
132+
NamedTuple{SS′ART}((s, s′, a, r, t))
133133
end
134134

135-
function sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds)
135+
function sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
136136
s, s′, a, r, t = sample(s, ts, Val(SSART), inds)
137-
l = consecutive_view(ts[:legal_actions_mask], inds)
137+
l = consecutive_view(ts[:next_legal_actions_mask], inds)
138138
NamedTuple{SSLART}((s, s′, l, a, r, t))
139139
end

0 commit comments

Comments
 (0)