Skip to content

Commit b79926a

Browse files
authored
[oneMKL] Add support for oneSparseMatrixCSC (#526)
1 parent 04d8080 commit b79926a

File tree

5 files changed

+342
-87
lines changed

5 files changed

+342
-87
lines changed

lib/mkl/array.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export oneSparseMatrixCSR, oneSparseMatrixCOO
1+
export oneSparseMatrixCSR, oneSparseMatrixCSC, oneSparseMatrixCOO
22

33
# Abstract supertype of the sparse array types declared in this file, parameterized by
# value type `Tv`, index type `Ti` and dimensionality `N`.
abstract type oneAbstractSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
44
# Alias for the one-dimensional case (`N = 1`) of `oneAbstractSparseArray`.
const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
@@ -13,6 +13,15 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1313
nnz::Ti
1414
end
1515

16+
# Sparse matrix with values of type `Tv` and indices of type `Ti`, stored as three
# `oneVector`s (column pointers, row indices, nonzero values) alongside a oneMKL
# matrix handle.
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
    handle::matrix_handle_t    # oneMKL sparse matrix handle
    colPtr::oneVector{Ti}      # column pointer array
    rowVal::oneVector{Ti}      # row index of each stored entry
    nzVal::oneVector{Tv}       # stored values
    dims::NTuple{2, Int}       # (rows, columns)
    nnz::Ti                    # number of stored entries
end
24+
1625
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1726
handle::matrix_handle_t
1827
rowInd::oneVector{Ti}
@@ -37,6 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz
3746
# Return the vector of stored values (`nzVal` field) of the sparse matrix `A`.
function SparseArrays.nonzeros(A::oneAbstractSparseMatrix)
    return A.nzVal
end
3847

3948
for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC,
49+
:oneSparseMatrixCSC => :SparseMatrixCSC,
4050
:oneSparseMatrixCOO => :SparseMatrixCSC]
4151
@eval Base.show(io::IOContext, x::$gpu) =
4252
show(io, $cpu(x))

lib/mkl/interfaces.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,27 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
77
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
88
end
99

10+
# `mul!` backend for `C = op(A) * B` with `A` stored as `oneSparseMatrixCSC`.
#
# `tA` is the operation requested by LinearAlgebra. The symmetric/Hermitian wrapper
# flags ('S', 's', 'H', 'h') are mapped to 'N', as in the CSR method.
#
# The flag is forwarded unflipped on purpose: `sparse_gemv!` for `oneSparseMatrixCSC`
# already applies `flip_trans` before calling oneMKL (the handle stores Aᵀ in CSR form).
# Flipping here as well would cancel that out and compute with the wrong operand.
# `_add.alpha` and `_add.beta` are passed straight through to `sparse_gemv!`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    return sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
end
14+
1015
# `mul!` backend for a CSR sparse matrix `A` times a dense matrix `B`.
# Symmetric/Hermitian wrapper flags on either operand are replaced by 'N' before the
# call is handed to `sparse_gemm!` together with `_add.alpha` and `_add.beta`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
    wrapper_flags = ('S', 's', 'H', 'h')
    opA = tA in wrapper_flags ? 'N' : tA
    opB = tB in wrapper_flags ? 'N' : tB
    return sparse_gemm!(opA, opB, _add.alpha, A, B, _add.beta, C)
end
1520

16-
for SparseMatrixType in (:oneSparseMatrixCSR,)
17-
@eval begin
18-
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat
19-
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
20-
end
21-
end
21+
# `mul!` backend for `C = op(A) * op(B)` with `A` stored as `oneSparseMatrixCSC`.
#
# Symmetric/Hermitian wrapper flags ('S', 's', 'H', 'h') on either operand are mapped
# to 'N', as in the CSR method.
#
# `tA` is forwarded unflipped on purpose: `sparse_gemm!` for `oneSparseMatrixCSC`
# already applies `flip_trans` to its `transa` argument (the handle stores Aᵀ in CSR
# form). Flipping here too would undo that and multiply by the wrong operand.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    return sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end
26+
27+
# Triangular solve with a CSR sparse matrix `A` and a vector right-hand side `B`; the
# result is written to `C` by `sparse_trsv!`.
# `tfun` (identity / transpose / anything else) is translated to 'N' / 'T' / 'C'.
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where T <: BlasFloat
    trans = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return sparse_trsv!(uploc, trans, isunitc, one(T), A, B, C)
end
30+
31+
# Triangular solve with a CSR sparse matrix `A` and a dense matrix right-hand side `B`;
# the result is written to `C` by `sparse_trsm!`.
# `tfun` (identity / transpose / anything else) is translated to 'N' / 'T' / 'C';
# the right-hand side is always passed with flag 'N'.
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
    transA = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return sparse_trsm!(uploc, transA, 'N', isunitc, one(T), A, B, C)
end

lib/mkl/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,6 @@ end
113113
ptrs = pointer.(batch)
114114
return oneArray(ptrs)
115115
end
116+
117+
# Swap the transposition flag: 'N' becomes 'T'; every other flag becomes 'N'.
function flip_trans(trans::Char)
    if trans == 'N'
        return 'T'
    end
    return 'N'
end
118+
# Swap the triangle flag: 'L' becomes 'U'; every other flag becomes 'L'.
function flip_uplo(uplo::Char)
    if uplo == 'L'
        return 'U'
    end
    return 'L'
end

lib/mkl/wrappers_sparse.jl

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
3535
A_csc = SparseMatrixCSC(At |> transpose)
3636
return A_csc
3737
end
38+
39+
# Build a `oneSparseMatrixCSC` from a host `SparseMatrixCSC`: copy the three CSC
# arrays into `oneVector`s, register them with a fresh oneMKL matrix handle, and
# attach a finalizer that releases the handle.
function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
    href = Ref{matrix_handle_t}()
    onemklXsparse_init_matrix_handle(href)
    nrows, ncols = size(A)
    d_colptr = oneVector{$intty}(A.colptr)
    d_rowval = oneVector{$intty}(A.rowval)
    d_nzval = oneVector{$elty}(A.nzval)
    stored = length(A.nzval)
    q = global_queue(context(d_nzval), device())
    # The row/column counts are swapped because the CSC arrays of A are the CSR arrays
    # of Aᵀ. 'O' is presumably the one-based index flag -- confirm against the binding.
    $fname(sycl_queue(q), href[], ncols, nrows, 'O', d_colptr, d_rowval, d_nzval)
    dA = oneSparseMatrixCSC{$elty, $intty}(href[], d_colptr, d_rowval, d_nzval, (nrows, ncols), stored)
    finalizer(sparse_release_matrix_handle, dA)
    return dA
end
53+
54+
# Copy a `oneSparseMatrixCSC` back to a host `SparseMatrixCSC`.
# The three arrays are converted to `Vector`s and reused directly with the stored
# dimensions; no oneMKL handle is needed for this conversion.
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
    nrows, ncols = A.dims
    return SparseMatrixCSC(nrows, ncols, Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
end
3859
end
3960
end
4061

@@ -100,6 +121,33 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
100121
end
101122
end
102123

124+
# gemv wrappers for CSC matrices (real element types only).
# The oneMKL handle of a CSC matrix holds Aᵀ in CSR form, so the caller's `trans`
# flag is flipped before it reaches oneMKL.
for SparseMatrix in (:oneSparseMatrixCSC,)
    for (fname, elty) in ((:onemklSsparse_gemv, :Float32),
                          (:onemklDsparse_gemv, :Float64))
        @eval begin
            # Call the oneMKL gemv routine on `A.handle` with the flipped flag and
            # return `y`, which the routine receives as its output argument.
            function sparse_gemv!(trans::Char,
                                  alpha::Number,
                                  A::$SparseMatrix{$elty},
                                  x::oneStridedVector{$elty},
                                  beta::Number,
                                  y::oneStridedVector{$elty})

                q = global_queue(context(x), device())
                handle_trans = flip_trans(trans)
                $fname(sycl_queue(q), handle_trans, alpha, A.handle, x, beta, y)
                return y
            end
        end
    end

    @eval begin
        # Run the oneMKL gemv optimization on `A.handle` for operation `trans`
        # (flipped for the same reason as above) and return `A`.
        function sparse_optimize_gemv!(trans::Char, A::$SparseMatrix)
            q = global_queue(context(A.nzVal), device(A.nzVal))
            handle_trans = flip_trans(trans)
            onemklXsparse_optimize_gemv(sycl_queue(q), handle_trans, A.handle)
            return A
        end
    end
end
150+
103151
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
104152
(:onemklDsparse_gemm, :Float64),
105153
(:onemklCsparse_gemm, :ComplexF32),
@@ -139,6 +187,43 @@ function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSpars
139187
return A
140188
end
141189

190+
# gemm wrappers for CSC matrices (real element types only).
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
                      (:onemklDsparse_gemm, :Float64))
    @eval begin
        # Sparse-dense product with `A` in CSC storage; `C` is the output argument
        # handed to oneMKL and is returned.
        #
        # `transa` refers to `A` itself; it is flipped before reaching oneMKL because
        # the handle holds Aᵀ in CSR form.
        # Throws `ArgumentError` when the column count of `op(B)` does not match `C`.
        function sparse_gemm!(transa::Char,
                              transb::Char,
                              alpha::Number,
                              A::oneSparseMatrixCSC{$elty},
                              B::oneStridedMatrix{$elty},
                              beta::Number,
                              C::oneStridedMatrix{$elty})

            mB, nB = size(B)
            mC, nC = size(C)
            (nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns."))
            (mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns."))
            # oneMKL expects the number of columns of C here. `size(B, 2)` only equals
            # that when `transb == 'N'`; for a transposed B it is the inner dimension.
            nrhs = nC
            ldb = max(1, stride(B, 2))
            ldc = max(1, stride(C, 2))
            queue = global_queue(context(C), device())
            # 'C' is the dense layout flag (presumably column-major -- confirm).
            $fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
            return C
        end
    end
end
214+
215+
# Run the oneMKL gemm optimization on the handle of the CSC matrix `A` and return `A`.
# `trans` is flipped first, as in the other CSC wrappers.
function sparse_optimize_gemm!(trans::Char, A::oneSparseMatrixCSC)
    q = global_queue(context(A.nzVal), device(A.nzVal))
    handle_trans = flip_trans(trans)
    onemklXsparse_optimize_gemm(sycl_queue(q), handle_trans, A.handle)
    return A
end
220+
221+
# Run the "advanced" oneMKL gemm optimization on the handle of the CSC matrix `A`,
# passing the dense operand's flag `transB` and `nrhs`; returns `A`.
# `trans` is flipped first; the dense layout flag is fixed to 'C'.
function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSparseMatrixCSC)
    q = global_queue(context(A.nzVal), device(A.nzVal))
    handle_trans = flip_trans(trans)
    onemklXsparse_optimize_gemm_advanced(sycl_queue(q), 'C', handle_trans, transB, A.handle, nrhs)
    return A
end
226+
142227
for (fname, elty) in ((:onemklSsparse_symv, :Float32),
143228
(:onemklDsparse_symv, :Float64),
144229
(:onemklCsparse_symv, :ComplexF32),
@@ -158,6 +243,23 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
158243
end
159244
end
160245

246+
# symv wrappers for CSC matrices (real element types only).
for (fname, elty) in ((:onemklSsparse_symv, :Float32),
                      (:onemklDsparse_symv, :Float64))
    @eval begin
        # Call the oneMKL symv routine on `A.handle` and return `y`, which the routine
        # receives as its output argument.
        # `uplo` is flipped before the call, mirroring how the other CSC wrappers flip
        # `trans`.
        function sparse_symv!(uplo::Char,
                              alpha::Number,
                              A::oneSparseMatrixCSC{$elty},
                              x::oneStridedVector{$elty},
                              beta::Number,
                              y::oneStridedVector{$elty})

            q = global_queue(context(y), device())
            handle_uplo = flip_uplo(uplo)
            $fname(sycl_queue(q), handle_uplo, alpha, A.handle, x, beta, y)
            return y
        end
    end
end
262+
161263
for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
162264
(:onemklDsparse_trmv, :Float64),
163265
(:onemklCsparse_trmv, :ComplexF32),
@@ -185,6 +287,34 @@ function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
185287
return A
186288
end
187289

290+
# Only trans = 'N' is supported with oneSparseMatrixCSR.
291+
# A oneSparseMatrixCSC handle stores Aᵀ in CSR form, so applying A itself would need trans = 'T': sparse "trmv" cannot be supported for oneSparseMatrixCSC.
292+
#
293+
# for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
294+
# (:onemklDsparse_trmv, :Float64))
295+
# @eval begin
296+
# function sparse_trmv!(uplo::Char,
297+
# trans::Char,
298+
# diag::Char,
299+
# alpha::Number,
300+
# A::oneSparseMatrixCSC{$elty},
301+
# x::oneStridedVector{$elty},
302+
# beta::Number,
303+
# y::oneStridedVector{$elty})
304+
#
305+
# queue = global_queue(context(y), device())
306+
# $fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y)
307+
# y
308+
# end
309+
# end
310+
# end
311+
#
312+
# function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
313+
# queue = global_queue(context(A.nzVal), device(A.nzVal))
314+
# onemklXsparse_optimize_trmv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
315+
# return A
316+
# end
317+
188318
for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
189319
(:onemklDsparse_trsv, :Float64),
190320
(:onemklCsparse_trsv, :ComplexF32),
@@ -211,6 +341,33 @@ function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
211341
return A
212342
end
213343

344+
# Only trans = 'N' is supported with oneSparseMatrixCSR.
345+
# A oneSparseMatrixCSC handle stores Aᵀ in CSR form, so solving with A itself would need trans = 'T': sparse "trsv" cannot be supported for oneSparseMatrixCSC.
346+
#
347+
# for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
348+
# (:onemklDsparse_trsv, :Float64))
349+
# @eval begin
350+
# function sparse_trsv!(uplo::Char,
351+
# trans::Char,
352+
# diag::Char,
353+
# alpha::Number,
354+
# A::oneSparseMatrixCSC{$elty},
355+
# x::oneStridedVector{$elty},
356+
# y::oneStridedVector{$elty})
357+
#
358+
# queue = global_queue(context(y), device())
359+
# $fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, y)
360+
# y
361+
# end
362+
# end
363+
# end
364+
#
365+
# function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
366+
# queue = global_queue(context(A.nzVal), device(A.nzVal))
367+
# onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
368+
# return A
369+
# end
370+
214371
for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
215372
(:onemklDsparse_trsm, :Float64),
216373
(:onemklCsparse_trsm, :ComplexF32),
@@ -252,3 +409,46 @@ function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A
252409
onemklXsparse_optimize_trsm_advanced(sycl_queue(queue), 'C', uplo, trans, diag, A.handle, nrhs)
253410
return A
254411
end
412+
413+
# Only transA = 'N' is supported with oneSparseMatrixCSR.
414+
# A oneSparseMatrixCSC handle stores Aᵀ in CSR form, so solving with A itself would need transA = 'T': sparse "trsm" cannot be supported for oneSparseMatrixCSC.
415+
#
416+
# for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
417+
# (:onemklDsparse_trsm, :Float64))
418+
# @eval begin
419+
# function sparse_trsm!(uplo::Char,
420+
# transA::Char,
421+
# transX::Char,
422+
# diag::Char,
423+
# alpha::Number,
424+
# A::oneSparseMatrixCSC{$elty},
425+
# X::oneStridedMatrix{$elty},
426+
# Y::oneStridedMatrix{$elty})
427+
#
428+
# mX, nX = size(X)
429+
# mY, nY = size(Y)
430+
# (mX != mY) && (transX == 'N') && throw(ArgumentError("X and Y must have the same number of rows."))
431+
# (nX != nY) && (transX == 'N') && throw(ArgumentError("X and Y must have the same number of columns."))
432+
# (nX != mY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of rows."))
433+
# (mX != nY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of columns."))
434+
# nrhs = size(X, 2)
435+
# ldx = max(1,stride(X,2))
436+
# ldy = max(1,stride(Y,2))
437+
# queue = global_queue(context(Y), device())
438+
# $fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
439+
# Y
440+
# end
441+
# end
442+
# end
443+
#
444+
# function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
445+
# queue = global_queue(context(A.nzVal), device(A.nzVal))
446+
# onemklXsparse_optimize_trsm(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
447+
# return A
448+
# end
449+
#
450+
# function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A::oneSparseMatrixCSC)
451+
# queue = global_queue(context(A.nzVal), device(A.nzVal))
452+
# onemklXsparse_optimize_trsm_advanced(sycl_queue(queue), 'C', uplo, flip_trans(trans), diag, A.handle, nrhs)
453+
# return A
454+
# end

0 commit comments

Comments
 (0)