Skip to content

Commit aebbe0e

Browse files
authored
Merge branch 'main' into MultiStepSampler
2 parents 9986fcc + e55919c commit aebbe0e

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

src/samplers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7676
w = StatsBase.FrequencyWeights(p)
7777
w .*= e.sampleable_inds[1:end-1]
7878
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
79+
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batch_size)
7980
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
8081
end
8182

@@ -165,6 +166,7 @@ end
165166
export NStepBatchSampler
166167

167168
"""
169+
168170
NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)
169171
170172
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Random
99
import ReinforcementLearningTrajectories.StatsBase.sample
1010
import StatsBase.countmap
1111

12+
1213
struct TestAdaptor end
1314

1415
gpu(x) = Adapt.adapt(TestAdaptor(), x)

0 commit comments

Comments
 (0)