Skip to content

Commit 919a2b7

Browse files
amontoisonmichel2323
authored andcommitted
[oneMKL] Add support for oneSparseMatrixCSC
1 parent 3d3278d commit 919a2b7

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

lib/mkl/array.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1313
nnz::Ti
1414
end
1515

16+
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
17+
handle::matrix_handle_t
18+
colPtr::oneVector{Ti}
19+
rowVal::oneVector{Ti}
20+
nzVal::oneVector{Tv}
21+
dims::NTuple{2,Int}
22+
nnz::Ti
23+
end
24+
1625
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1726
handle::matrix_handle_t
1827
rowInd::oneVector{Ti}
@@ -37,6 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz
3746
SparseArrays.nonzeros(A::oneAbstractSparseMatrix) = A.nzVal
3847

3948
for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC,
49+
:oneSparseMatrixCSC => :SparseMatrixCSC,
4050
:oneSparseMatrixCOO => :SparseMatrixCSC]
4151
@eval Base.show(io::IOContext, x::$gpu) =
4252
show(io, $cpu(x))

lib/mkl/interfaces.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,23 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
77
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
88
end
99

10+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
11+
tA = tA in ('S', 's', 'H', 'h') ? 'T' : (tA == 'N' ? 'T' : 'N')
12+
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
13+
end
14+
1015
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
1116
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
1217
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
1318
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
1419
end
1520

21+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
22+
tA = tA in ('S', 's', 'H', 'h') ? 'T' : (tA == 'N' ? 'T' : 'N')
23+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
24+
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
25+
end
26+
1627
for SparseMatrixType in (:oneSparseMatrixCSR,)
1728
@eval begin
1829
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat

lib/mkl/wrappers_sparse.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
3535
A_csc = SparseMatrixCSC(At |> transpose)
3636
return A_csc
3737
end
38+
39+
function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
40+
handle_ptr = Ref{matrix_handle_t}()
41+
onemklXsparse_init_matrix_handle(handle_ptr)
42+
m, n = size(A)
43+
colPtr = oneVector{$intty}(A.colptr)
44+
rowVal = oneVector{$intty}(A.rowval)
45+
nzVal = oneVector{$elty}(A.nzval)
46+
nnzA = length(A.nzval)
47+
queue = global_queue(context(nzVal), device())
48+
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49+
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50+
finalizer(sparse_release_matrix_handle, dA)
51+
return dA
52+
end
53+
54+
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
55+
handle_ptr = Ref{matrix_handle_t}()
56+
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
57+
return A_csc
58+
end
3859
end
3960
end
4061

test/onemkl.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,17 @@ end
10881088
end
10891089
end
10901090

1091+
@testset "oneSparseMatrixCSC" begin
1092+
(T isa Complex) && continue
1093+
for S in (Int32, Int64)
1094+
A = sprand(T, 20, 10, 0.5)
1095+
A = SparseMatrixCSC{T, S}(A)
1096+
B = oneSparseMatrixCSC(A)
1097+
A2 = SparseMatrixCSC(B)
1098+
@test A == A2
1099+
end
1100+
end
1101+
10911102
@testset "oneSparseMatrixCOO" begin
10921103
for S in (Int32, Int64)
10931104
A = sprand(T, 20, 10, 0.5)

0 commit comments

Comments
 (0)