Skip to content

Commit 43e514f

Browse files
committed
Port reverse from CUDA
1 parent 3be4a09 commit 43e514f

File tree

4 files changed

+197
-0
lines changed

4 files changed

+197
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "11.2.5"
55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
8+
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
89
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
910
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -24,6 +25,7 @@ JLD2Ext = "JLD2"
2425
[compat]
2526
Adapt = "4.0"
2627
GPUArraysCore = "= 0.2.0"
28+
GPUToolbox = "0.2, 0.3"
2729
JLD2 = "0.4, 0.5"
2830
KernelAbstractions = "0.9.28"
2931
LLVM = "3.9, 4, 5, 6, 7, 8, 9"

src/GPUArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module GPUArrays
22

3+
using GPUToolbox
34
using KernelAbstractions
45
using Serialization
56
using Random
@@ -26,6 +27,7 @@ include("host/construction.jl")
2627
## integrations and specialized methods
2728
include("host/base.jl")
2829
include("host/indexing.jl")
30+
include("host/reverse.jl")
2931
include("host/broadcast.jl")
3032
include("host/mapreduce.jl")
3133
include("host/linalg.jl")

src/host/reverse.jl

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# reversing
2+
3+
# The kernel works by treating the array as 1d. After reversing along dimension x, an element at
4+
# pos [i1, i2, i3, ... , i{x}, ..., i{n}] will be at
5+
# pos [i1, i2, i3, ... , d{x} - i{x} + 1, ..., i{n}] where d{x} is the size of dimension x
6+
7+
# Out-of-place reversal: every thread copies a single element of `input` to its mirrored
# position in `output`. Only the dimensions listed in `dims` (and holding more than one
# element) are mirrored; `input` and `output` must have the same size.
function _reverse(input::AnyGPUArray{T, N}, output::AnyGPUArray{T, N};
                  dims=1:ndims(input)) where {T, N}
    @assert size(input) == size(output)

    # per-dimension flag: is this dimension actually reversed?
    flip = ntuple(d -> d in dims && size(input, d) > 1, N)
    # mirroring index `i` along dimension `d` yields `size(input, d) + 1 - i`
    mirror = size(input) .+ 1
    # converts an ND-index in the data array to the linear index
    to_linear = LinearIndices(input)
    # converts a linear index in the data array to the ND-index
    to_cartesian = CartesianIndices(input)

    ## COV_EXCL_START
    @kernel unsafe_indices=true function kernel(input, output)
        group_offset = Int32(@groupsize()[1]) * (@index(Group, Linear) - 1i32)
        src = group_offset + @index(Local, Linear)

        # threads whose index falls past the end of the array do nothing
        @inbounds if src <= length(input)
            pos = Tuple(to_cartesian[src])
            mirrored = ifelse.(flip, mirror .- pos, pos)
            dst = to_linear[mirrored...]
            output[dst] = input[src]
        end
    end
    ## COV_EXCL_STOP

    nthreads = 256

    kernel(get_backend(input), nthreads)(input, output; ndrange=length(input))
end
36+
37+
# In-place reversal along `dims`, swapping pairs of elements.
#
# The array is split in two along the last non-singleton dimension being reversed. One
# thread is assigned to every element of the first half (including the middle slice when
# that dimension has odd size) and exchanges it with its mirrored counterpart.
# Returns early, without launching a kernel, when no dimension needs reversing.
function _reverse!(data::AnyGPUArray{T, N}; dims=1:ndims(data)) where {T, N}
    rev_dims = ntuple((d)-> d in dims && size(data, d) > 1, N)
    half_dim = findlast(rev_dims)
    if isnothing(half_dim)
        # no reverse operation needed at all in this case.
        return
    end
    # mirroring index `i` along dimension `d` yields `size(data, d) + 1 - i`
    ref = size(data) .+ 1
    # converts an ND-index in the data array to the linear index
    lin_idx = LinearIndices(data)
    reduced_size = ntuple((d)->ifelse(d==half_dim, cld(size(data,d),2), size(data,d)), N)
    reduced_length = prod(reduced_size)
    # converts a linear index in a reduced array to an ND-index, but using the reduced size
    nd_idx = CartesianIndices(reduced_size)

    ## COV_EXCL_START
    @kernel unsafe_indices=true function kernel(data)
        offset_in = Int32(@groupsize()[1]) * (@index(Group, Linear) - 1i32)
        # index of this thread within the reduced (halved) array; kept separate from the
        # linear indices into `data` below so that no variable changes its type
        thread_idx = offset_in + @index(Local, Linear)

        @inbounds if thread_idx <= reduced_length
            idx = Tuple(nd_idx[thread_idx])
            index_in = lin_idx[idx...]
            idx = ifelse.(rev_dims, ref .- idx, idx)
            index_out = lin_idx[idx...]

            # each pair is swapped exactly once; elements that map onto themselves, or
            # whose partner is handled by another thread, are skipped
            if index_in < index_out
                temp = data[index_out]
                data[index_out] = data[index_in]
                data[index_in] = temp
            end
        end
    end
    ## COV_EXCL_STOP

    # NOTE: we launch slightly more than half the number of elements in the array as threads.
    # The last non-singleton dimension along which to reverse is used to define how the array is split.
    # Only the middle row in case of an odd array dimension could cause trouble, but this is prevented by
    # ignoring the threads that cross the mid-point

    nthreads = 256

    # only `reduced_length` threads have work to do, so do not launch one per element
    kernel(get_backend(data), nthreads)(data; ndrange=reduced_length)
end
82+
83+
84+
# n-dimensional API
85+
86+
# Reverse `data` in place along `dims` and return `data`.
#
# `dims` may be `:` (all dimensions), a single integer, or any iterable of integers.
# Throws an `ArgumentError` when `dims` is not iterable or contains a value outside
# `1:ndims(data)`.
function Base.reverse!(data::AnyGPUArray{T, N}; dims=:) where {T, N}
    if isa(dims, Colon)
        dims = 1:ndims(data)
    end
    if !applicable(iterate, dims)
        throw(ArgumentError("dimension $dims is not an iterable"))
    end
    if !all(1 .≤ dims .≤ ndims(data))
        throw(ArgumentError("dimension $dims is not 1 ≤ $dims ≤ $(ndims(data))"))
    end

    _reverse!(data; dims=dims)

    return data
end
101+
102+
# out-of-place
103+
# Return a reversed copy of `input` along `dims`, leaving `input` untouched.
#
# `dims` may be `:` (all dimensions), a single integer, or any iterable of integers.
# Throws an `ArgumentError` when `dims` is not iterable or contains a value outside
# `1:ndims(input)`.
function Base.reverse(input::AnyGPUArray{T, N}; dims=:) where {T, N}
    if isa(dims, Colon)
        dims = 1:ndims(input)
    end
    if !applicable(iterate, dims)
        throw(ArgumentError("dimension $dims is not an iterable"))
    end
    if !all(1 .≤ dims .≤ ndims(input))
        throw(ArgumentError("dimension $dims is not 1 ≤ $dims ≤ $(ndims(input))"))
    end

    if all(d -> size(input, d) == 1, dims)
        # no reverse operation needed at all in this case.
        return copy(input)
    else
        output = similar(input)
        _reverse(input, output; dims=dims)
        return output
    end
end
123+
124+
125+
# 1-dimensional API
126+
127+
# in-place
128+
# Reverse the elements `start:stop` of the vector `data` in place and return `data`.
Base.@propagate_inbounds function Base.reverse!(data::AnyGPUVector{T}, start::Integer,
                                                stop::Integer=length(data)) where {T}
    segment = view(data, start:stop)
    _reverse!(segment)
    return data
end
133+
134+
# Reverse the whole vector in place by delegating to the ranged method.
function Base.reverse!(data::AnyGPUVector{T}) where {T}
    return @inbounds reverse!(data, 1, length(data))
end
135+
136+
# out-of-place
137+
# Return a copy of the vector `input` in which the elements `start:stop` are reversed;
# elements before `start` and after `stop` are copied over unchanged.
Base.@propagate_inbounds function Base.reverse(input::AnyGPUVector{T}, start::Integer,
                                               stop::Integer=length(input)) where {T}
    output = similar(input)

    # leading part that is not reversed
    if start > 1
        copyto!(output, 1, input, 1, start-1)
    end
    # reversed middle part
    _reverse(view(input, start:stop), view(output, start:stop))
    # trailing part that is not reversed
    if stop < length(input)
        copyto!(output, stop+1, input, stop+1)
    end

    return output
end
147+
148+
# Return a reversed copy of the whole vector by delegating to the ranged method.
function Base.reverse(data::AnyGPUVector{T}) where {T}
    return @inbounds reverse(data, 1, length(data))
end

test/testsuite/base.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,51 @@ end
381381
gA = reshape(AT(A),4)
382382
end
383383

384+
@testset "reverse" begin
    # vectors, out-of-place: whole vector, from a start index, and a start:stop range
    @test compare(x->reverse(x), AT, rand(Float32, 1000))
    @test compare(x->reverse(x, 10), AT, rand(Float32, 1000))
    @test compare(x->reverse(x, 10, 90), AT, rand(Float32, 1000))

    # vectors, in-place: same three call forms
    @test compare(x->reverse!(x), AT, rand(Float32, 1000))
    @test compare(x->reverse!(x, 10), AT, rand(Float32, 1000))
    @test compare(x->reverse!(x, 10, 90), AT, rand(Float32, 1000))

    # n-d arrays, one dimension at a time (out-of-place, then in-place against the CPU)
    for shape in ([1, 2, 4, 3], [4, 2], [5], [2^5, 2^5, 2^5]),
        dim in 1:length(shape)
        @test compare(x->reverse(x; dims=dim), AT, rand(Float32, shape...))

        host = rand(Float32, shape...)
        dev = AT(host)
        reverse!(dev; dims=dim)
        @test Array(dev) == reverse(host; dims=dim)
    end

    # n-d arrays, several dimensions at once (tuples of dimensions and `:`)
    for shape in ([1, 2, 4, 3], [2^5, 2^5, 2^5]),
        dim in ((1,2),(2,3),(1,3),:)
        @test compare(x->reverse(x; dims=dim), AT, rand(Float32, shape...))

        host = rand(Float32, shape...)
        dev = AT(host)
        reverse!(dev; dims=dim)
        @test Array(dev) == reverse(host; dims=dim)
    end

    # wrapped array
    @test compare(x->reverse(x), AT, reshape(rand(Float32, 2,2), 4))

    # out-of-range dimensions must be rejected by both variants
    host = rand(Float32, 1,2,3,4)
    dev = AT(host)
    @test_throws ArgumentError reverse!(dev, dims=5)
    @test_throws ArgumentError reverse!(dev, dims=0)
    @test_throws ArgumentError reverse(dev, dims=5)
    @test_throws ArgumentError reverse(dev, dims=0)
end
428+
384429
@testset "reinterpret" begin
385430
A = Int32[-1,-2,-3]
386431
dA = AT(A)

0 commit comments

Comments
 (0)