Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ JLArrays = "0.2"
KernelAbstractions = "0.9"
LinearAlgebra = "1"
LinearOperators = "2.3"
LinearOperatorCollection = "2"
LinearOperatorCollection = "2.1"
LRUCache = "1.6"
MPIFiles = "0.13, 0.14, 0.15, 0.16, 0.17"
ProgressMeter = "1.2"
Expand Down
102 changes: 102 additions & 0 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,108 @@ function LinearAlgebra.mul!(res::AbstractVector{T}, adj::Adjoint{T, OP}, t::Abst
return res
end

# Weighted normal-operator product res = Sᴴ W S x for dense multi-patch operators on the GPU.
# The weights are first lowered into a kernel-friendly representation; dispatch on that
# representation then picks either the fused single-kernel path or the generic fallback.
function LinearAlgebra.mul!(res::AbstractVector{Tc}, nop::NormalOp{Tc, OP}, x::AbstractVector) where {T, Tc <: Complex{T}, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{Tc, V}}
  kernelWeights = prepareKernelWeights(T, nop.weights)
  return mul_dense_normal!(res, nop, x, kernelWeights)
end
# Lower the operator weights into a representation the fused normal-operator kernel can
# consume, dispatching on the weight type:
#   * WeightingOp -> its weight vector (element-wise row weights, fusable)
#   * nothing     -> the scalar one(T), i.e. an unweighted product (fusable)
#   * any other operator -> `nothing`, signalling that kernel fusion is impossible and
#     the generic three-step fallback must be used instead
prepareKernelWeights(::Type{T}, weights::WeightingOp) where T = weights.weights
prepareKernelWeights(::Type{T}, weights::Nothing) where T = one(T)
# Unknown weighting operator, can't do kernel fusion
prepareKernelWeights(::Type{T}, weights) where T = nothing

# Generic (non-fused) normal-operator fallback, chosen when the weights could not be
# lowered for the kernel (`prepareKernelWeights` returned `nothing`): apply the forward
# operator, the weighting operator and the adjoint operator one after another.
function mul_dense_normal!(res::AbstractVector{Tc}, nop::NormalOp{Tc, OP}, x::AbstractVector, weights::Nothing) where {T, Tc <: Complex{T}, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{Tc, V}}
  innerOp = nop.parent
  # tmp = S * x
  mul!(nop.tmp, innerOp, x)
  # tmp = W * tmp — in-place with source == destination; NOTE(review): assumes
  # nop.weights tolerates this aliasing (true for element-wise weights) — confirm.
  mul!(nop.tmp, nop.weights, nop.tmp)
  # res = Sᴴ * tmp
  return mul!(res, adjoint(innerOp), nop.tmp)
end

# Fused normal-operator product res = Sᴴ W S x in a single kernel launch.
#
# One workgroup of 512 threads handles one operator row k: it computes the forward dot
# product for that row via a grid-stride loop plus a shared-memory tree reduction,
# scales the result by the row weight, and scatters the adjoint contribution back into
# `res` with atomic adds on the real/imaginary parts (complex atomics are unsupported,
# hence the reinterpret to a 2×n real view at the call site).
#
# Fix vs. previous version: the partial-reduction branch started its tree reduction at
# s = div(min(grid_stride, N), 2), which silently drops elements whenever N is not a
# power of two (e.g. N = 5 never folds shared[5] into the sum). Since every shared-memory
# slot beyond N is zero-initialized, it is always safe — and correct — to reduce over the
# full (power-of-two) group size instead.
function mul_dense_normal!(res::AbstractVector{Tc}, nop::NormalOp{Tc, OP}, x::AbstractVector, weights) where {T, Tc <: Complex{T}, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{Tc, V}}
  backend = get_backend(res)
  op = nop.parent
  res .= zero(T) # We need to zero the result, because we are using += in the kernel

  @kernel cpu = false function dense_mul_normal!(res, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx), @Const(weights))
    ### Forward operator ###
    # Each group/block handles a single row of the operator
    operator_row = @index(Group, Linear) # k
    patch = RowToPatch[operator_row] # p
    patch_row = mod1(operator_row, M) # j
    smIdx = patchToSMIdx[patch]
    sign = eltype(x)(signs[patch_row, smIdx])
    @uniform grid_stride = prod(@groupsize())
    N = Int32(size(xss, 1))

    # We want to use a grid-stride loop to perform the sparse matrix-vector product.
    # Each thread performs a single element-wise multiplication and reduction in its shared spot.
    # Afterwards we reduce over the shared memory.
    localIdx = @index(Local, Linear)
    shared = @localmem eltype(x) grid_stride
    shared[localIdx] = zero(eltype(x))

    # First we sum in a temp variable, hoping that it is accumulated in a register,
    # since registers are faster than shared memory.
    tmp = zero(eltype(x))
    @unroll for i = localIdx:grid_stride:N
      tmp += sign * S[xss[i, patch], patch_row, smIdx] * x[xcc[i, patch]]
    end
    # Threads with no work (localIdx > N) keep tmp == 0, so their slots stay zero.
    shared[localIdx] = tmp
    @synchronize

    # Now we need to reduce the shared memory to get the final result
    full_reduction = grid_stride < N
    if full_reduction
      # All 512 slots carry contributions: manually unrolled tree reduction, first step
      # combines pairs 256 apart.
      localIdx <= 256 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+256])
      @synchronize
      localIdx <= 128 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+128])
      @synchronize
      localIdx <= 64 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+64])
      @synchronize
      localIdx <= 32 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+32])
      @synchronize
      localIdx <= 16 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+16])
      @synchronize
      localIdx <= 8 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+8])
      @synchronize
      localIdx <= 4 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+4])
      @synchronize
      localIdx <= 2 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+2])
      @synchronize
      localIdx == 1 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx+1])
      @synchronize
    else
      # Partial reduction (N <= grid_stride): still reduce over the full group size.
      # Slots beyond N are zero, so including them is safe, and using the power-of-two
      # group size avoids dropping elements when N is not a power of two.
      @private s = div(grid_stride, Int32(2))
      while s > Int32(0)
        if localIdx <= s
          shared[localIdx] = shared[localIdx] + shared[localIdx+s]
        end
        s >>= 1
        @synchronize
      end
    end

    ### Adjoint operator ###
    val = shared[1] * get_kernel_weights(weights, operator_row)
    @unroll for i = localIdx:grid_stride:N
      tmp2 = sign * conj(S[xss[i, patch], patch_row, smIdx]) * val
      # @atomic is not supported for ComplexF32 numbers
      Atomix.@atomic res[1, xcc[i, patch]] += real(tmp2)
      Atomix.@atomic res[2, xcc[i, patch]] += imag(tmp2)
    end
  end

  kernel = dense_mul_normal!(backend, 512, (512, size(op, 1)))
  kernel(reinterpret(reshape, T, res), x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx, weights; ndrange = (512, size(op, 1)))
  return res
end

# Kaczmarz specific functions
function RegularizedLeastSquares.dot_with_matrix_row(op::DenseMultiPatchOperator{T, V}, x::AbstractArray{T}, k::Int) where {T, V <: AbstractGPUArray}
patch = @allowscalar op.RowToPatch[k]
Expand Down
Loading