|
| 1 | +# reversing |
| 2 | + |
| 3 | +# the kernel works by treating the array as 1d. after reversing by dimension x an element at |
| 4 | +# pos [i1, i2, i3, ... , i{x}, ..., i{n}] will be at |
| 5 | +# pos [i1, i2, i3, ... , d{x} - i{x} + 1, ..., i{n}] where d{x} is the size of dimension x |
| 6 | + |
| 7 | +# out-of-place version, copying a single value per thread from input to output |
| 8 | +function _reverse(input::AnyGPUArray{T, N}, output::AnyGPUArray{T, N}; |
| 9 | + dims=1:ndims(input)) where {T, N} |
| 10 | + @assert size(input) == size(output) |
| 11 | + rev_dims = ntuple((d)-> d in dims && size(input, d) > 1, N) |
| 12 | + ref = size(input) .+ 1 |
| 13 | + # converts an ND-index in the data array to the linear index |
| 14 | + lin_idx = LinearIndices(input) |
| 15 | + # converts a linear index in a reduced array to an ND-index, but using the reduced size |
| 16 | + nd_idx = CartesianIndices(input) |
| 17 | + |
| 18 | + ## COV_EXCL_START |
| 19 | + @kernel unsafe_indices=true function kernel(input, output) |
| 20 | + offset_in = Int32(@groupsize()[1]) * (@index(Group, Linear) - 1i32) |
| 21 | + index_in = offset_in + @index(Local, Linear) |
| 22 | + |
| 23 | + @inbounds if index_in <= length(input) |
| 24 | + idx = Tuple(nd_idx[index_in]) |
| 25 | + idx = ifelse.(rev_dims, ref .- idx, idx) |
| 26 | + index_out = lin_idx[idx...] |
| 27 | + output[index_out] = input[index_in] |
| 28 | + end |
| 29 | + end |
| 30 | + ## COV_EXCL_STOP |
| 31 | + |
| 32 | + nthreads = 256 |
| 33 | + |
| 34 | + kernel(get_backend(input), nthreads)(input, output; ndrange=length(input)) |
| 35 | +end |
| 36 | + |
| 37 | +# in-place version, swapping elements on half the number of threads |
| 38 | +function _reverse!(data::AnyGPUArray{T, N}; dims=1:ndims(data)) where {T, N} |
| 39 | + rev_dims = ntuple((d)-> d in dims && size(data, d) > 1, N) |
| 40 | + half_dim = findlast(rev_dims) |
| 41 | + if isnothing(half_dim) |
| 42 | + # no reverse operation needed at all in this case. |
| 43 | + return |
| 44 | + end |
| 45 | + ref = size(data) .+ 1 |
| 46 | + # converts an ND-index in the data array to the linear index |
| 47 | + lin_idx = LinearIndices(data) |
| 48 | + reduced_size = ntuple((d)->ifelse(d==half_dim, cld(size(data,d),2), size(data,d)), N) |
| 49 | + reduced_length = prod(reduced_size) |
| 50 | + # converts a linear index in a reduced array to an ND-index, but using the reduced size |
| 51 | + nd_idx = CartesianIndices(reduced_size) |
| 52 | + |
| 53 | + ## COV_EXCL_START |
| 54 | + @kernel unsafe_indices=true function kernel(data) |
| 55 | + offset_in = Int32(@groupsize()[1]) * (@index(Group, Linear) - 1i32) |
| 56 | + index_in = offset_in + @index(Local, Linear) |
| 57 | + |
| 58 | + @inbounds if index_in <= reduced_length |
| 59 | + idx = Tuple(nd_idx[index_in]) |
| 60 | + index_in = lin_idx[idx...] |
| 61 | + idx = ifelse.(rev_dims, ref .- idx, idx) |
| 62 | + index_out = lin_idx[idx...] |
| 63 | + |
| 64 | + if index_in < index_out |
| 65 | + temp = data[index_out] |
| 66 | + data[index_out] = data[index_in] |
| 67 | + data[index_in] = temp |
| 68 | + end |
| 69 | + end |
| 70 | + end |
| 71 | + ## COV_EXCL_STOP |
| 72 | + |
| 73 | + # NOTE: we launch slightly more than half the number of elements in the array as threads. |
| 74 | + # The last non-singleton dimension along which to reverse is used to define how the array is split. |
| 75 | + # Only the middle row in case of an odd array dimension could cause trouble, but this is prevented by |
| 76 | + # ignoring the threads that cross the mid-point |
| 77 | + |
| 78 | + nthreads = 256 |
| 79 | + |
| 80 | + kernel(get_backend(data), nthreads)(data; ndrange=length(data)) |
| 81 | +end |
| 82 | + |
| 83 | + |
| 84 | +# n-dimensional API |
| 85 | + |
| 86 | +function Base.reverse!(data::AnyGPUArray{T, N}; dims=:) where {T, N} |
| 87 | + if isa(dims, Colon) |
| 88 | + dims = 1:ndims(data) |
| 89 | + end |
| 90 | + if !applicable(iterate, dims) |
| 91 | + throw(ArgumentError("dimension $dims is not an iterable")) |
| 92 | + end |
| 93 | + if !all(1 .≤ dims .≤ ndims(data)) |
| 94 | + throw(ArgumentError("dimension $dims is not 1 ≤ $dims ≤ $(ndims(data))")) |
| 95 | + end |
| 96 | + |
| 97 | + _reverse!(data; dims=dims) |
| 98 | + |
| 99 | + return data |
| 100 | +end |
| 101 | + |
| 102 | +# out-of-place |
| 103 | +function Base.reverse(input::AnyGPUArray{T, N}; dims=:) where {T, N} |
| 104 | + if isa(dims, Colon) |
| 105 | + dims = 1:ndims(input) |
| 106 | + end |
| 107 | + if !applicable(iterate, dims) |
| 108 | + throw(ArgumentError("dimension $dims is not an iterable")) |
| 109 | + end |
| 110 | + if !all(1 .≤ dims .≤ ndims(input)) |
| 111 | + throw(ArgumentError("dimension $dims is not 1 ≤ $dims ≤ $(ndims(input))")) |
| 112 | + end |
| 113 | + |
| 114 | + if all(size(input)[[dims...]].==1) |
| 115 | + # no reverse operation needed at all in this case. |
| 116 | + return copy(input) |
| 117 | + else |
| 118 | + output = similar(input) |
| 119 | + _reverse(input, output; dims=dims) |
| 120 | + return output |
| 121 | + end |
| 122 | +end |
| 123 | + |
| 124 | + |
| 125 | +# 1-dimensional API |
| 126 | + |
| 127 | +# in-place |
| 128 | +Base.@propagate_inbounds function Base.reverse!(data::AnyGPUVector{T}, start::Integer, |
| 129 | + stop::Integer=length(data)) where {T} |
| 130 | + _reverse!(view(data, start:stop)) |
| 131 | + return data |
| 132 | +end |
| 133 | + |
| 134 | +Base.reverse!(data::AnyGPUVector{T}) where {T} = @inbounds reverse!(data, 1, length(data)) |
| 135 | + |
| 136 | +# out-of-place |
| 137 | +Base.@propagate_inbounds function Base.reverse(input::AnyGPUVector{T}, start::Integer, |
| 138 | + stop::Integer=length(input)) where {T} |
| 139 | + output = similar(input) |
| 140 | + |
| 141 | + start > 1 && copyto!(output, 1, input, 1, start-1) |
| 142 | + _reverse(view(input, start:stop), view(output, start:stop)) |
| 143 | + stop < length(input) && copyto!(output, stop+1, input, stop+1) |
| 144 | + |
| 145 | + return output |
| 146 | +end |
| 147 | + |
| 148 | +Base.reverse(data::AnyGPUVector{T}) where {T} = @inbounds reverse(data, 1, length(data)) |
0 commit comments