diff --git a/Project.toml b/Project.toml
index 7920fa2..53c05dc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -40,7 +40,7 @@ JLArrays = "0.2"
 KernelAbstractions = "0.9"
 LinearAlgebra = "1"
 LinearOperators = "2.3"
-LinearOperatorCollection = "2"
+LinearOperatorCollection = "2.1"
 LRUCache = "1.6"
 MPIFiles = "0.13, 0.14, 0.15, 0.16, 0.17"
 ProgressMeter = "1.2"
diff --git a/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl b/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
index 81ab1e7..621439f 100644
--- a/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
+++ b/ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
@@ -135,6 +135,108 @@ function LinearAlgebra.mul!(res::AbstractVector{T}, adj::Adjoint{T, OP}, t::Abst
   return res
 end
 
+function LinearAlgebra.mul!(res::AbstractVector{Tc}, nop::NormalOp{Tc, OP}, x::AbstractVector) where {T, Tc <: Complex{T}, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{Tc, V}}
+  weights = prepareKernelWeights(T, nop.weights)
+  return mul_dense_normal!(res, nop, x, weights)
+end
+# Known weights
+prepareKernelWeights(T, weights::WeightingOp) = weights.weights
+prepareKernelWeights(::Type{T}, weights::Nothing) where T = one(T)
+# Unknown weight type, can't do kernel fusion
+prepareKernelWeights(T, weights) = nothing
+
+function mul_dense_normal!(res::AbstractVector{Tc}, nop::NormalOp{Tc, OP}, x::AbstractVector, weights::Nothing) where {T, Tc <: Complex{T}, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{Tc, V}}
+  op = nop.parent
+  mul!(nop.tmp, op, x)
+  mul!(nop.tmp, nop.weights, nop.tmp)
+  mul!(res, adjoint(op), nop.tmp)
+end
+
+function mul_dense_normal!(res::AbstractVector{Tc}, nop::NormalOp{Tc, OP}, x::AbstractVector, weights) where {T, Tc <: Complex{T}, V <: AbstractGPUArray, OP <: DenseMultiPatchOperator{Tc, V}}
+  backend = get_backend(res)
+  op = nop.parent
+  res .= zero(T) # We need to zero the result, because we are using += in the kernel
+
+  @kernel cpu = false function dense_mul_normal!(res, @Const(x), @Const(S), @Const(xcc), @Const(xss), @Const(signs), @Const(M), @Const(RowToPatch), @Const(patchToSMIdx), @Const(weights))
+    ### Forward operator ###
+    # Each group/block handles a single row of the operator
+    operator_row = @index(Group, Linear) # k
+    patch = RowToPatch[operator_row] # p
+    patch_row = mod1(operator_row, M) # j
+    smIdx = patchToSMIdx[patch]
+    sign = eltype(x)(signs[patch_row, smIdx])
+    @uniform grid_stride = prod(@groupsize())
+    N = Int32(size(xss, 1))
+
+    # We want to use a grid-stride loop to perform the sparse matrix-vector product.
+    # Each thread performs a single element-wise multiplication and reduction in its shared spot.
+    # Afterwards we reduce over the shared memory.
+    localIdx = @index(Local, Linear)
+    shared = @localmem eltype(x) grid_stride
+    shared[localIdx] = zero(eltype(x))
+
+    # First we iterate over the sparse indices
+    tmp = zero(eltype(x))
+    @unroll for i = localIdx:grid_stride:N
+      tmp += sign * S[xss[i, patch], patch_row, smIdx] * x[xcc[i, patch]]
+    end
+    shared[localIdx] = tmp
+    # We first sum in a temp variable, hoping that it is accumulated in a register, since registers are faster than shared memory
+    @synchronize
+
+    # Now we need to reduce the shared memory to get the final result
+    full_reduction = grid_stride < N
+    if full_reduction
+
+      # For a full reduction we know s = 512 and can (manually) unroll our loop
+      #localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
+      #@synchronize
+      localIdx <= 256 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 256])
+      @synchronize
+      localIdx <= 128 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 128])
+      @synchronize
+      localIdx <= 64 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 64])
+      @synchronize
+      localIdx <= 32 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 32])
+      @synchronize
+      localIdx <= 16 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 16])
+      @synchronize
+      localIdx <= 8 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 8])
+      @synchronize
+      localIdx <= 4 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 4])
+      @synchronize
+      localIdx <= 2 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 2])
+      @synchronize
+      localIdx == 1 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 1])
+      @synchronize
+
+
+    else
+      @private s = div(min(grid_stride, N), Int32(2))
+      while s > Int32(0)
+        if localIdx <= s
+          shared[localIdx] = shared[localIdx] + shared[localIdx + s]
+        end
+        s >>= 1
+        @synchronize
+      end
+    end
+
+    ### Adjoint operator ###
+    val = shared[1] * get_kernel_weights(weights, operator_row)
+    @unroll for i = localIdx:grid_stride:N
+      tmp2 = sign * conj(S[xss[i, patch], patch_row, smIdx]) * val
+      # @atomic is not supported for ComplexF32 numbers
+      Atomix.@atomic res[1, xcc[i, patch]] += real(tmp2)
+      Atomix.@atomic res[2, xcc[i, patch]] += imag(tmp2)
+    end
+  end
+
+  kernel = dense_mul_normal!(backend, 512, (512, size(op, 1)))
+  kernel(reinterpret(reshape, T, res), x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx, weights; ndrange = (512, size(op, 1)))
+  return res
+end
+
 # Kaczmarz specific functions
 function RegularizedLeastSquares.dot_with_matrix_row(op::DenseMultiPatchOperator{T, V}, x::AbstractArray{T}, k::Int) where {T, V <: AbstractGPUArray}
   patch = @allowscalar op.RowToPatch[k]