
Commit 03489a0

Merge pull request #56 from JuliaReinforcementLearning/nstep
NStepBatchSampler
2 parents c89ed6f + b190c60

6 files changed: +133 −169 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "ReinforcementLearningTrajectories"
 uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
-version = "0.3.4"
+version = "0.3.5"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/common/sum_tree.jl

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ Random.rand(rng::AbstractRNG, t::SumTree{T}) where {T} = get(t, rand(rng, T) * t
 Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t)

 function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
-    inds, priorities = Vector{Int}(undef, n), Vector{Float64}(undef, n)
+    inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n)
     for i in 1:n
         v = (i - 1 + rand(rng, T)) / n
         ind, p = get(t, v * t.tree[1])
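The only functional change here is that the sampled priorities now use the tree's element type `T` rather than a hard-coded `Float64`. For context, the loop shown above draws one value per stratum of `[0, 1)` before scaling by the total priority `t.tree[1]`; a standalone sketch of that stratification (illustration only, not code from this commit):

using Random

rng = Random.default_rng()
n = 4
# each draw v lands in its own slice [(i-1)/n, i/n), mirroring `(i - 1 + rand(rng, T)) / n` above
vs = [(i - 1 + rand(rng, Float32)) / n for i in 1:n]
@assert all(i -> (i - 1) / n <= vs[i] < i / n, 1:n)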

src/samplers.jl

Lines changed: 63 additions & 46 deletions
@@ -163,75 +163,92 @@ end

 export NStepBatchSampler

-mutable struct NStepBatchSampler{traces}
+"""
+    NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)
+
+Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
+The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
+that in up to `n > 1` steps later in the buffer. The reward will be
+the discounted sum of the `n` rewards, with `γ` as the discount factor.
+
+NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size` is set
+to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
+of partial observability, for example when the state is approximated by `stack_size` consecutive
+frames.
+"""
+mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}}
     n::Int # !!! n starts from 1
     γ::Float32
     batch_size::Int
-    stack_size::Union{Nothing,Int}
+    stack_size::S
     rng::Any
 end

-NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
-NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
-
+NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
+function NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names}
+    @assert n >= 1 "n must be ≥ 1."
+    ss = stack_size == 1 ? nothing : stack_size
+    NStepBatchSampler{names, typeof(ss)}(n, γ, batch_size, ss, rng)
+end

-function valid_range_nbatchsampler(s::NStepBatchSampler, ts)
-    # think about the extreme case where s.stack_size == 1 and s.n == 1
-    isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))
+#return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
+function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
+    range = copy(eb.sampleable_inds)
+    ns = Vector{Int}(undef, length(eb.sampleable_inds))
+    stack_size = isnothing(s.stack_size) ? 1 : s.stack_size
+    for idx in eachindex(range)
+        step_number = eb.step_numbers[idx]
+        range[idx] = step_number >= stack_size && eb.sampleable_inds[idx]
+        ns[idx] = min(s.n, eb.episodes_lengths[idx] - step_number + 1)
+    end
+    return range, ns
 end
+
 function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
-    valid_range = valid_range_nbatchsampler(s, ts)
-    inds = rand(s.rng, valid_range, s.batch_size)
-    StatsBase.sample(s, ts, Val(names), inds)
+    StatsBase.sample(s, ts, Val(names))
 end

-function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
-    valid_range = valid_range_nbatchsampler(s, ts)
-    valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
-    inds = rand(s.rng, valid_range, s.batch_size)
-    StatsBase.sample(s, ts, Val(names), inds)
+function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names
+    weights, ns = valid_range(s, t)
+    inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batch_size)
+    fetch(s, t, Val(names), inds, ns)
 end

-
-function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
-    if isnothing(nbs.stack_size)
-        s = ts[:state][inds]
-        s′ = ts[:next_state][inds.+(nbs.n-1)]
-    else
-        s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
-        s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
-    end
-
-    a = ts[:action][inds]
-    t_horizon = ts[:terminal][[x + j for j in 0:nbs.n-1, x in inds]]
-    r_horizon = ts[:reward][[x + j for j in 0:nbs.n-1, x in inds]]
-
-    @assert ndims(t_horizon) == 2
-    t = any(t_horizon, dims=1) |> vec
-
-    @assert ndims(r_horizon) == 2
-    r = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
-        foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
-    end
-
-    NamedTuple{SS′ART}(map(collect, (s, s′, a, r, t)))
+function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names
+    NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names))
 end

-function StatsBase.sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
-    s, s′, a, r, t = StatsBase.sample(s, ts, Val(SSART), inds)
-    l = consecutive_view(ts[:next_legal_actions_mask], inds)
-    NamedTuple{SSLART}(map(collect, (s, s′, l, a, r, t)))
+#state and next_state have specialized fetch methods due to stack_size
+fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds]
+fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stack_size+1:0, x in inds]]
+fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1]
+fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stack_size+1:0, (idx,x) in enumerate(inds)]]
+
+#reward due to discounting
+function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns)
+    rewards = Vector{eltype(trace)}(undef, length(inds))
+    for (i,idx) in enumerate(inds)
+        rewards_to_go = trace[idx:idx+ns[i]-1]
+        rewards[i] = foldr((x,y)->x + s.γ*y, rewards_to_go)
+    end
+    return rewards
 end
+#terminal is that of the nth step
+fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val{:terminal}, inds, ns) = trace[inds .+ ns .- 1]
+#right multiplex traces must be n-step sampled
+fetch(::NStepBatchSampler, trace::RelativeTrace{1,0}, ::Val, inds, ns) = trace[inds .+ ns .- 1]
+#other types of trace are fetched at inds, i.e. sampled normally
+fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds]

 function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names}
     t = e.traces
     st = deepcopy(t.priorities)
-    st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
+    valids, ns = valid_range(s, e)
+    st .*= valids[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
     inds, priorities = rand(s.rng, st, s.batch_size)
-
     merge(
         (key=t.keys[inds], priority=priorities),
-        StatsBase.sample(s, t.traces, Val(names), inds)
+        fetch(s, e, Val(names), inds, ns)
     )
 end
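With this change the sampler is constructed from the traces it will sample, so `names` is inferred from `keys(t)`, and each trace name is fetched by its own `fetch` method. A minimal usage sketch, assembled from the docstring above and the updated tests in test/samplers.jl (the buffer layout, the push! pattern, and the call via `ReinforcementLearningTrajectories.StatsBase.sample` follow the tests; the size comment assumes scalar states):

using ReinforcementLearningTrajectories
import ReinforcementLearningTrajectories.StatsBase.sample

# SARTSA layout: state/next_state and action/next_action are multiplexed traces
eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
# names are taken from keys(eb); stack_size=2 stacks the previous state with the current one
s = NStepBatchSampler(eb, n=3, γ=0.99f0, stack_size=2, batch_size=32)

push!(eb, (state = 1, action = 1))  # first step of an episode
for i in 1:5
    push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = i == 5))
end

batch = sample(s, eb)  # NamedTuple over keys(eb), e.g. batch.reward holds the n-step discounted returns
size(batch.state)      # (stack_size, batch_size) for scalar states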

src/traces.jl

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,7 @@ Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),)
 Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)
 Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...)

-@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!
+@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, Base.eltype

 #By default, AbstractTrace have infinity capacity (like a Vector). This method is specialized for
 #CircularArraySARTSTraces in common.jl. The functions below are made that way to avoid type piracy.
@@ -94,6 +94,7 @@ Base.getindex(s::RelativeTrace{0,-1}, I) = getindex(s.trace, I)
 Base.getindex(s::RelativeTrace{1,0}, I) = getindex(s.trace, I .+ 1)
 Base.setindex!(s::RelativeTrace{0,-1}, v, I) = setindex!(s.trace, v, I)
 Base.setindex!(s::RelativeTrace{1,0}, v, I) = setindex!(s.trace, v, I .+ 1)
+Base.eltype(t::RelativeTrace) = eltype(t.trace)
 capacity(t::RelativeTrace) = capacity(t.trace)

 """

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ using ReinforcementLearningTrajectories
 using CircularArrayBuffers, DataStructures
 using StableRNGs
 using Test
+import ReinforcementLearningTrajectories.StatsBase.sample
 using CUDA
 using Adapt
 using Random

test/samplers.jl

Lines changed: 65 additions & 120 deletions
@@ -1,3 +1,4 @@
+import ReinforcementLearningTrajectories.fetch
 @testset "Samplers" begin
     @testset "BatchSampler" begin
         sz = 32
@@ -74,132 +75,76 @@

     #! format: off
     @testset "NStepSampler" begin
-        γ = 0.9
+        γ = 0.99
         n_stack = 2
         n_horizon = 3
-        batch_size = 4
-
-        t1 = MultiplexTraces{(:state, :next_state)}(1:10) +
-            MultiplexTraces{(:action, :next_action)}(iseven.(1:10)) +
-            Traces(
-                reward=1:9,
-                terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-            )
-
-        s1 = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
-
-        xs = RLTrajectories.StatsBase.sample(s1, t1)
-
-        @test size(xs.state) == (n_stack, batch_size)
-        @test size(xs.next_state) == (n_stack, batch_size)
-        @test size(xs.action) == (batch_size,)
-        @test size(xs.reward) == (batch_size,)
-        @test size(xs.terminal) == (batch_size,)
-
-
-        state_size = (2,3)
-        n_state = reduce(*, state_size)
-        total_length = 10
-        t2 = MultiplexTraces{(:state, :next_state)}(
-                reshape(1:n_state * total_length, state_size..., total_length)
-            ) +
-            MultiplexTraces{(:action, :next_action)}(iseven.(1:total_length)) +
-            Traces(
-                reward=1:total_length-1,
-                terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-            )
-
-        xs2 = RLTrajectories.StatsBase.sample(s1, t2)
-
-        @test size(xs2.state) == (state_size..., n_stack, batch_size)
-        @test size(xs2.next_state) == (state_size..., n_stack, batch_size)
-        @test size(xs2.action) == (batch_size,)
-        @test size(xs2.reward) == (batch_size,)
-        @test size(xs2.terminal) == (batch_size,)
-
-        inds = [3, 5, 7]
-        xs3 = RLTrajectories.StatsBase.sample(s1, t2, Val(SS′ART), inds)
-
-        @test xs3.state == cat(
-            (
-                reshape(n_state * (i-n_stack)+1: n_state * i, state_size..., n_stack)
-                for i in inds
-            )...
-            ;dims=length(state_size) + 2
-        )
-
-        @test xs3.next_state == xs3.state .+ (n_state * n_horizon)
-        @test xs3.action == iseven.(inds)
-        @test xs3.terminal == [any(t2[:terminal][i: i+n_horizon-1]) for i in inds]
-
-        # manual calculation
-        @test xs3.reward[1] ≈ 3 + γ * 4 # terminated at step 4
-        @test xs3.reward[2] ≈ 5 + γ * (6 + γ * 7)
-        @test xs3.reward[3] ≈ 7 + γ * (8 + γ * 9)
-    end
-    #! format: on
-
-    @testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
-        n=1
-        γ=0.99f0
-
-        t = Trajectory(
-            container=CircularPrioritizedTraces(
-                CircularArraySARTSTraces(
-                    capacity=5,
-                    state=Float32 => (4,),
-                );
-                default_priority=100.0f0
-            ),
-            sampler=NStepBatchSampler{SS′ART}(
-                n=n,
-                γ=γ,
-                batch_size=32,
-            ),
-            controller=InsertSampleRatioController(
-                threshold=100,
-                n_inserted=-1
-            )
-        )
+        batch_size = 1000
+        eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
+        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)

-        push!(t, (state = 1, action = true))
-        for i = 1:9
-            push!(t, (state = i+1, action = true, reward = i, terminal = false))
+        push!(eb, (state = 1, action = 1))
+        for i = 1:5
+            push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
        end
-
-        b = RLTrajectories.StatsBase.sample(t)
-        @test haskey(b, :priority)
-        @test sum(b.action .== 0) == 0
-    end
-
-
-    @testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
-        n=1
-        γ=0.99f0
-
-        t = Trajectory(
-            container=CircularArraySARTSTraces(
-                capacity=5,
-                state=Float32 => (4,),
-            ),
-            sampler=NStepBatchSampler{SS′ART}(
-                n=n,
-                γ=γ,
-                batch_size=32,
-            ),
-            controller=InsertSampleRatioController(
-                threshold=100,
-                n_inserted=-1
-            )
-        )
-
-        push!(t, (state = 1, action = true))
-        for i = 1:9
-            push!(t, (state = i+1, action = true, reward = i, terminal = false))
+        push!(eb, (state = 7, action = 7))
+        for (j,i) = enumerate(8:11)
+            push!(eb, (state = i, action =i, reward = i-1, terminal = false))
+        end
+        weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+        @test weights == [0,1,1,1,1,0,0,1,1,1,0]
+        @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
+        inds = [i for i in eachindex(weights) if weights[i] == 1]
+        batch = sample(s1, eb)
+        for key in keys(eb)
+            @test haskey(batch, key)
        end
+        #state: samples with stack_size
+        states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
+        @test states == [1 2 3 4 7 8 9;
+                         2 3 4 5 8 9 10]
+        @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
+        #next_state: samples with stack_size and nsteps forward
+        next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
+        @test next_states == [4 5 5 5 10 10 10;
+                              5 6 6 6 11 11 11]
+        @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
+        #action: samples normally
+        actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
+        @test actions == inds
+        @test all(in(actions), unique(batch[:action]))
+        #next_action: is a multiplex trace: should automatically sample nsteps forward
+        next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
+        @test next_actions == [5, 6, 6, 6, 11, 11, 11]
+        @test all(in(next_actions), unique(batch[:next_action]))
+        #reward: discounted sum
+        rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
+        @test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4+0.99^2*5, 4+0.99*5, 5, 8+0.99*9+0.99^2*10, 9+0.99*10, 10]
+        @test all(in(rewards), unique(batch[:reward]))
+        #terminal: nsteps forward
+        terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
+        @test terminals == [0,1,1,1,0,0,0]
+
+        ### CircularPrioritizedTraces and NStepBatchSampler
+        γ = 0.99
+        n_horizon = 3
+        batch_size = 4
+        eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
+        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)

-        b = RLTrajectories.StatsBase.sample(t)
-        @test sum(b.action .== 0) == 0
+        push!(eb, (state = 1, action = 1))
+        for i = 1:5
+            push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
+        end
+        push!(eb, (state = 7, action = 7))
+        for (j,i) = enumerate(8:11)
+            push!(eb, (state = i, action =i, reward = i-1, terminal = false))
+        end
+        weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+        inds = [i for i in eachindex(weights) if weights[i] == 1]
+        batch = sample(s1, eb)
+        for key in (keys(eb)..., :key, :priority)
+            @test haskey(batch, key)
+        end
    end

    @testset "EpisodesSampler" begin
