Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion lib/mkl/array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export oneSparseMatrixCSR, oneSparseMatrixCOO
export oneSparseMatrixCSR, oneSparseMatrixCSC, oneSparseMatrixCOO

abstract type oneAbstractSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
Expand All @@ -13,6 +13,15 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
nnz::Ti
end

mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
colPtr::oneVector{Ti}
rowVal::oneVector{Ti}
nzVal::oneVector{Tv}
dims::NTuple{2,Int}
nnz::Ti
end

mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
rowInd::oneVector{Ti}
Expand All @@ -37,6 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz
SparseArrays.nonzeros(A::oneAbstractSparseMatrix) = A.nzVal

for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC,
:oneSparseMatrixCSC => :SparseMatrixCSC,
:oneSparseMatrixCOO => :SparseMatrixCSC]
@eval Base.show(io::IOContext, x::$gpu) =
show(io, $cpu(x))
Expand Down
23 changes: 17 additions & 6 deletions lib/mkl/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,27 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
end

function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end

for SparseMatrixType in (:oneSparseMatrixCSR,)
@eval begin
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
end
end
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end

function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where T <: BlasFloat
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
end

function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
end
3 changes: 3 additions & 0 deletions lib/mkl/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ end
ptrs = pointer.(batch)
return oneArray(ptrs)
end

flip_trans(trans::Char) = trans == 'N' ? 'T' : 'N'
flip_uplo(uplo::Char) = uplo == 'L' ? 'U' : 'L'
200 changes: 200 additions & 0 deletions lib/mkl/wrappers_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
A_csc = SparseMatrixCSC(At |> transpose)
return A_csc
end

function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
onemklXsparse_init_matrix_handle(handle_ptr)
m, n = size(A)
colPtr = oneVector{$intty}(A.colptr)
rowVal = oneVector{$intty}(A.rowval)
nzVal = oneVector{$elty}(A.nzval)
nnzA = length(A.nzval)
queue = global_queue(context(nzVal), device())
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
return dA
end

function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
return A_csc
end
end
end

Expand Down Expand Up @@ -100,6 +121,33 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
end
end

for SparseMatrix in (:oneSparseMatrixCSC,)
for (fname, elty) in ((:onemklSsparse_gemv, :Float32),
(:onemklDsparse_gemv, :Float64))
@eval begin
function sparse_gemv!(trans::Char,
alpha::Number,
A::$SparseMatrix{$elty},
x::oneStridedVector{$elty},
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(x), device())
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
y
end
end
end

@eval begin
function sparse_optimize_gemv!(trans::Char, A::$SparseMatrix)
queue = global_queue(context(A.nzVal), device(A.nzVal))
onemklXsparse_optimize_gemv(sycl_queue(queue), flip_trans(trans), A.handle)
return A
end
end
end

for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
(:onemklDsparse_gemm, :Float64),
(:onemklCsparse_gemm, :ComplexF32),
Expand Down Expand Up @@ -139,6 +187,43 @@ function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSpars
return A
end

for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
(:onemklDsparse_gemm, :Float64))
@eval begin
function sparse_gemm!(transa::Char,
transb::Char,
alpha::Number,
A::oneSparseMatrixCSC{$elty},
B::oneStridedMatrix{$elty},
beta::Number,
C::oneStridedMatrix{$elty})

mB, nB = size(B)
mC, nC = size(C)
(nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns."))
(mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns."))
nrhs = size(B, 2)
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
queue = global_queue(context(C), device())
$fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
C
end
end
end

function sparse_optimize_gemm!(trans::Char, A::oneSparseMatrixCSC)
queue = global_queue(context(A.nzVal), device(A.nzVal))
onemklXsparse_optimize_gemm(sycl_queue(queue), flip_trans(trans), A.handle)
return A
end

function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSparseMatrixCSC)
queue = global_queue(context(A.nzVal), device(A.nzVal))
onemklXsparse_optimize_gemm_advanced(sycl_queue(queue), 'C', flip_trans(trans), transB, A.handle, nrhs)
return A
end

for (fname, elty) in ((:onemklSsparse_symv, :Float32),
(:onemklDsparse_symv, :Float64),
(:onemklCsparse_symv, :ComplexF32),
Expand All @@ -158,6 +243,23 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
end
end

for (fname, elty) in ((:onemklSsparse_symv, :Float32),
(:onemklDsparse_symv, :Float64))
@eval begin
function sparse_symv!(uplo::Char,
alpha::Number,
A::oneSparseMatrixCSC{$elty},
x::oneStridedVector{$elty},
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(y), device())
$fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y)
y
end
end
end

for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
(:onemklDsparse_trmv, :Float64),
(:onemklCsparse_trmv, :ComplexF32),
Expand Down Expand Up @@ -185,6 +287,34 @@ function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
return A
end

# Only trans = 'N' is supported with oneSparseMatrixCSR.
# We can't use any trick to support sparse "trmv" for oneSparseMatrixCSC.
#
# for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
# (:onemklDsparse_trmv, :Float64))
# @eval begin
# function sparse_trmv!(uplo::Char,
# trans::Char,
# diag::Char,
# alpha::Number,
# A::oneSparseMatrixCSC{$elty},
# x::oneStridedVector{$elty},
# beta::Number,
# y::oneStridedVector{$elty})
#
# queue = global_queue(context(y), device())
# $fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y)
# y
# end
# end
# end
#
# function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
# queue = global_queue(context(A.nzVal), device(A.nzVal))
# onemklXsparse_optimize_trmv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
# return A
# end

for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
(:onemklDsparse_trsv, :Float64),
(:onemklCsparse_trsv, :ComplexF32),
Expand All @@ -211,6 +341,33 @@ function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
return A
end

# Only trans = 'N' is supported with oneSparseMatrixCSR.
# We can't use any trick to support sparse "trsv" for oneSparseMatrixCSC.
#
# for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
# (:onemklDsparse_trsv, :Float64))
# @eval begin
# function sparse_trsv!(uplo::Char,
# trans::Char,
# diag::Char,
# alpha::Number,
# A::oneSparseMatrixCSC{$elty},
# x::oneStridedVector{$elty},
# y::oneStridedVector{$elty})
#
# queue = global_queue(context(y), device())
# $fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, y)
# y
# end
# end
# end
#
# function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
# queue = global_queue(context(A.nzVal), device(A.nzVal))
# onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
# return A
# end

for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
(:onemklDsparse_trsm, :Float64),
(:onemklCsparse_trsm, :ComplexF32),
Expand Down Expand Up @@ -252,3 +409,46 @@ function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A
onemklXsparse_optimize_trsm_advanced(sycl_queue(queue), 'C', uplo, trans, diag, A.handle, nrhs)
return A
end

# Only transA = 'N' is supported with oneSparseMatrixCSR.
# We can't use any trick to support sparse "trsm" for oneSparseMatrixCSC.
#
# for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
# (:onemklDsparse_trsm, :Float64))
# @eval begin
# function sparse_trsm!(uplo::Char,
# transA::Char,
# transX::Char,
# diag::Char,
# alpha::Number,
# A::oneSparseMatrixCSC{$elty},
# X::oneStridedMatrix{$elty},
# Y::oneStridedMatrix{$elty})
#
# mX, nX = size(X)
# mY, nY = size(Y)
# (mX != mY) && (transX == 'N') && throw(ArgumentError("X and Y must have the same number of rows."))
# (nX != nY) && (transX == 'N') && throw(ArgumentError("X and Y must have the same number of columns."))
# (nX != mY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of rows."))
# (mX != nY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of columns."))
# nrhs = size(X, 2)
# ldx = max(1,stride(X,2))
# ldy = max(1,stride(Y,2))
# queue = global_queue(context(Y), device())
# $fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
# Y
# end
# end
# end
#
# function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
# queue = global_queue(context(A.nzVal), device(A.nzVal))
# onemklXsparse_optimize_trsm(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
# return A
# end
#
# function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A::oneSparseMatrixCSC)
# queue = global_queue(context(A.nzVal), device(A.nzVal))
# onemklXsparse_optimize_trsm_advanced(sycl_queue(queue), 'C', uplo, flip_trans(trans), diag, A.handle, nrhs)
# return A
# end
Loading