You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names}
177
-
valid_range =isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))# think about the exteme case where s.stack_size == 1 and s.n == 1
182
+
valid_range =valid_range_nbatchsampler(s, ts)
178
183
inds =rand(s.rng, valid_range, s.batch_size)
179
184
StatsBase.sample(s, ts, Val(names), inds)
180
185
end
181
186
187
+
function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
188
+
valid_range =valid_range_nbatchsampler(s, ts)
189
+
valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
190
+
inds =rand(s.rng, valid_range, s.batch_size)
191
+
StatsBase.sample(s, ts, Val(names), inds)
192
+
end
193
+
194
+
182
195
function StatsBase.sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
0 commit comments