Skip to content

Commit 67b944b

Browse files
Merge pull request #22 from JuliaReinforcementLearning/jpsl/patch
2 parents 020a842 + 2faa73d commit 67b944b

File tree

3 files changed

+32
-13
lines changed

3 files changed

+32
-13
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
name = "CircularArrayBuffers"
22
uuid = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
33
authors = ["Jun Tian <[email protected]> and contributors"]
4-
version = "0.1.13"
4+
version = "0.1.14"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89

910
[compat]
1011
Adapt = "2, 3, 4"
1112
julia = "1"
1213

1314
[extras]
14-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1515
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617

1718
[targets]
1819
test = ["CUDA", "Test"]

src/CircularArrayBuffers.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,24 +76,42 @@ capacity(cb::CircularArrayBuffer{T,N}) where {T,N} = size(cb.buffer, N)
7676
isfull(cb::CircularArrayBuffer) = cb.nframes == capacity(cb)
7777
Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0
7878

79+
"""
80+
_buffer_index(cb::CircularArrayBuffer, i::Int)
81+
82+
Return the index of the `i`-th element in the buffer.
83+
"""
7984
@inline function _buffer_index(cb::CircularArrayBuffer, i::Int)
80-
ind = (cb.first - 1) * cb.step_size + i
81-
if ind > length(cb.buffer)
82-
ind - length(cb.buffer)
85+
idx = (cb.first - 1) * cb.step_size + i
86+
return wrap_index(idx, length(cb.buffer))
87+
end
88+
@inline _buffer_index(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(Base.Fix1(_buffer_index, cb), I)
89+
90+
"""
91+
wrap_index(idx, n)
92+
93+
Return the index of the `idx`-th element in the buffer, if index is one past the size, return 1, else error.
94+
"""
95+
function wrap_index(idx, n)
96+
if idx <= n
97+
return idx
98+
elseif idx <= 2n
99+
return idx - n
83100
else
84-
ind
101+
@info "oops! idx $(idx) > 2n $(2n)"
102+
return idx - n
85103
end
86104
end
87-
@inline _buffer_index(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(Base.Fix1(_buffer_index, cb), I)
88105

106+
"""
107+
_buffer_frame(cb::CircularArrayBuffer, i::Int)
108+
109+
Return the index of the `i`-th frame in the buffer.
110+
"""
89111
@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int)
90112
n = capacity(cb)
91113
idx = cb.first + i - 1
92-
if idx > n
93-
idx - n
94-
else
95-
idx
96-
end
114+
return wrap_index(idx, n)
97115
end
98116

99117
_buffer_frame(cb::CircularArrayBuffer, I::CartesianIndex) = CartesianIndex(map(i->_buffer_frame(cb, i), Tuple(I)))

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ if CUDA.functional()
231231
@test isempty(b) == true
232232
@test length(b) == 0
233233
@test size(b) == (0,)
234-
# element must has the exact same length with the element of buffer
234+
# element must have the exact same length with the element of buffer
235235
@test_throws Exception push!(b, [1, 2])
236236

237237
for x in 1:3

0 commit comments

Comments
 (0)