Skip to content

Commit 3c666f7

Browse files
committed
[oneMKL] Add support for oneSparseMatrixCSC
1 parent 3d3278d commit 3c666f7

File tree

5 files changed

+336
-85
lines changed

5 files changed

+336
-85
lines changed

lib/mkl/array.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export oneSparseMatrixCSR, oneSparseMatrixCOO
1+
export oneSparseMatrixCSR, oneSparseMatrixCSC, oneSparseMatrixCOO
22

33
abstract type oneAbstractSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
44
const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
@@ -13,6 +13,15 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1313
nnz::Ti
1414
end
1515

16+
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
17+
handle::matrix_handle_t
18+
colPtr::oneVector{Ti}
19+
rowVal::oneVector{Ti}
20+
nzVal::oneVector{Tv}
21+
dims::NTuple{2,Int}
22+
nnz::Ti
23+
end
24+
1625
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1726
handle::matrix_handle_t
1827
rowInd::oneVector{Ti}
@@ -37,6 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz
3746
SparseArrays.nonzeros(A::oneAbstractSparseMatrix) = A.nzVal
3847

3948
for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC,
49+
:oneSparseMatrixCSC => :SparseMatrixCSC,
4050
:oneSparseMatrixCOO => :SparseMatrixCSC]
4151
@eval Base.show(io::IOContext, x::$gpu) =
4252
show(io, $cpu(x))

lib/mkl/interfaces.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,35 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
77
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
88
end
99

10+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
11+
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
12+
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
13+
end
14+
1015
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
1116
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
1217
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
1318
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
1419
end
1520

16-
for SparseMatrixType in (:oneSparseMatrixCSR,)
17-
@eval begin
18-
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat
19-
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
20-
end
21-
end
21+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
22+
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
23+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
24+
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
25+
end
26+
27+
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where T <: BlasFloat
28+
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
29+
end
30+
31+
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSC{T}, B::oneVector{T}) where T <: BlasReal
32+
sparse_trsv!(flip_uplo(uploc), tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
33+
end
34+
35+
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
36+
sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
37+
end
38+
39+
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}) where T <: BlasReal
40+
sparse_trsm!(flip_uplo(uploc), tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
2241
end

lib/mkl/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,6 @@ end
113113
ptrs = pointer.(batch)
114114
return oneArray(ptrs)
115115
end
116+
117+
flip_trans(trans::Char) = trans == 'N' ? 'T' : 'N'
118+
flip_uplo(uplo::Char) = uplo == 'L' ? 'U' : 'L'

lib/mkl/wrappers_sparse.jl

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
3535
A_csc = SparseMatrixCSC(At |> transpose)
3636
return A_csc
3737
end
38+
39+
function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
40+
handle_ptr = Ref{matrix_handle_t}()
41+
onemklXsparse_init_matrix_handle(handle_ptr)
42+
m, n = size(A)
43+
colPtr = oneVector{$intty}(A.colptr)
44+
rowVal = oneVector{$intty}(A.rowval)
45+
nzVal = oneVector{$elty}(A.nzval)
46+
nnzA = length(A.nzval)
47+
queue = global_queue(context(nzVal), device())
48+
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
49+
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
50+
finalizer(sparse_release_matrix_handle, dA)
51+
return dA
52+
end
53+
54+
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
55+
handle_ptr = Ref{matrix_handle_t}()
56+
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
57+
return A_csc
58+
end
3859
end
3960
end
4061

@@ -100,6 +121,33 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
100121
end
101122
end
102123

124+
for SparseMatrix in (:oneSparseMatrixCSC,)
125+
for (fname, elty) in ((:onemklSsparse_gemv, :Float32),
126+
(:onemklDsparse_gemv, :Float64))
127+
@eval begin
128+
function sparse_gemv!(trans::Char,
129+
alpha::Number,
130+
A::$SparseMatrix{$elty},
131+
x::oneStridedVector{$elty},
132+
beta::Number,
133+
y::oneStridedVector{$elty})
134+
135+
queue = global_queue(context(x), device())
136+
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
137+
y
138+
end
139+
end
140+
end
141+
142+
@eval begin
143+
function sparse_optimize_gemv!(trans::Char, A::$SparseMatrix)
144+
queue = global_queue(context(A.nzVal), device(A.nzVal))
145+
onemklXsparse_optimize_gemv(sycl_queue(queue), flip_trans(trans), A.handle)
146+
return A
147+
end
148+
end
149+
end
150+
103151
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
104152
(:onemklDsparse_gemm, :Float64),
105153
(:onemklCsparse_gemm, :ComplexF32),
@@ -139,6 +187,43 @@ function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSpars
139187
return A
140188
end
141189

190+
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
191+
(:onemklDsparse_gemm, :Float64))
192+
@eval begin
193+
function sparse_gemm!(transa::Char,
194+
transb::Char,
195+
alpha::Number,
196+
A::oneSparseMatrixCSC{$elty},
197+
B::oneStridedMatrix{$elty},
198+
beta::Number,
199+
C::oneStridedMatrix{$elty})
200+
201+
mB, nB = size(B)
202+
mC, nC = size(C)
203+
(nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns."))
204+
(mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns."))
205+
nrhs = size(B, 2)
206+
ldb = max(1,stride(B,2))
207+
ldc = max(1,stride(C,2))
208+
queue = global_queue(context(C), device())
209+
$fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
210+
C
211+
end
212+
end
213+
end
214+
215+
function sparse_optimize_gemm!(trans::Char, A::oneSparseMatrixCSC)
216+
queue = global_queue(context(A.nzVal), device(A.nzVal))
217+
onemklXsparse_optimize_gemm(sycl_queue(queue), flip_trans(trans), A.handle)
218+
return A
219+
end
220+
221+
function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSparseMatrixCSC)
222+
queue = global_queue(context(A.nzVal), device(A.nzVal))
223+
onemklXsparse_optimize_gemm_advanced(sycl_queue(queue), 'C', filp_trans(trans), transB, A.handle, nrhs)
224+
return A
225+
end
226+
142227
for (fname, elty) in ((:onemklSsparse_symv, :Float32),
143228
(:onemklDsparse_symv, :Float64),
144229
(:onemklCsparse_symv, :ComplexF32),
@@ -158,6 +243,24 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
158243
end
159244
end
160245

246+
247+
for (fname, elty) in ((:onemklSsparse_symv, :Float32),
248+
(:onemklDsparse_symv, :Float64))
249+
@eval begin
250+
function sparse_symv!(uplo::Char,
251+
alpha::Number,
252+
A::oneSparseMatrixCSC{$elty},
253+
x::oneStridedVector{$elty},
254+
beta::Number,
255+
y::oneStridedVector{$elty})
256+
257+
queue = global_queue(context(y), device())
258+
$fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y)
259+
y
260+
end
261+
end
262+
end
263+
161264
for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
162265
(:onemklDsparse_trmv, :Float64),
163266
(:onemklCsparse_trmv, :ComplexF32),
@@ -185,6 +288,31 @@ function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
185288
return A
186289
end
187290

291+
for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
292+
(:onemklDsparse_trmv, :Float64))
293+
@eval begin
294+
function sparse_trmv!(uplo::Char,
295+
trans::Char,
296+
diag::Char,
297+
alpha::Number,
298+
A::oneSparseMatrixCSC{$elty},
299+
x::oneStridedVector{$elty},
300+
beta::Number,
301+
y::oneStridedVector{$elty})
302+
303+
queue = global_queue(context(y), device())
304+
$fname(sycl_queue(queue), flip_uplo(uplo), trans, diag, alpha, A.handle, x, beta, y)
305+
y
306+
end
307+
end
308+
end
309+
310+
function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
311+
queue = global_queue(context(A.nzVal), device(A.nzVal))
312+
onemklXsparse_optimize_trmv(sycl_queue(queue), flip_uplo(uplo), trans, diag, A.handle)
313+
return A
314+
end
315+
188316
for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
189317
(:onemklDsparse_trsv, :Float64),
190318
(:onemklCsparse_trsv, :ComplexF32),
@@ -211,6 +339,30 @@ function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparse
211339
return A
212340
end
213341

342+
for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
343+
(:onemklDsparse_trsv, :Float64))
344+
@eval begin
345+
function sparse_trsv!(uplo::Char,
346+
trans::Char,
347+
diag::Char,
348+
alpha::Number,
349+
A::oneSparseMatrixCSC{$elty},
350+
x::oneStridedVector{$elty},
351+
y::oneStridedVector{$elty})
352+
353+
queue = global_queue(context(y), device())
354+
$fname(sycl_queue(queue), filp_uplo(uplo), trans, diag, alpha, A.handle, x, y)
355+
y
356+
end
357+
end
358+
end
359+
360+
function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
361+
queue = global_queue(context(A.nzVal), device(A.nzVal))
362+
onemklXsparse_optimize_trsv(sycl_queue(queue), flip_uplo(uplo), trans, diag, A.handle)
363+
return A
364+
end
365+
214366
for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
215367
(:onemklDsparse_trsm, :Float64),
216368
(:onemklCsparse_trsm, :ComplexF32),
@@ -252,3 +404,43 @@ function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A
252404
onemklXsparse_optimize_trsm_advanced(sycl_queue(queue), 'C', uplo, trans, diag, A.handle, nrhs)
253405
return A
254406
end
407+
408+
for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
409+
(:onemklDsparse_trsm, :Float64))
410+
@eval begin
411+
function sparse_trsm!(uplo::Char,
412+
transA::Char,
413+
transX::Char,
414+
diag::Char,
415+
alpha::Number,
416+
A::oneSparseMatrixCSC{$elty},
417+
X::oneStridedMatrix{$elty},
418+
Y::oneStridedMatrix{$elty})
419+
420+
mX, nX = size(X)
421+
mY, nY = size(Y)
422+
(mX != mY) && (transX == 'N') && throw(ArgumentError("X and Y must have the same number of rows."))
423+
(nX != nY) && (transX == 'N') && throw(ArgumentError("X and Y must have the same number of columns."))
424+
(nX != mY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of rows."))
425+
(mX != nY) && (transX != 'N') && throw(ArgumentError("Xᵀ and Y must have the same number of columns."))
426+
nrhs = size(X, 2)
427+
ldx = max(1,stride(X,2))
428+
ldy = max(1,stride(Y,2))
429+
queue = global_queue(context(Y), device())
430+
$fname(sycl_queue(queue), 'C', transA, transX, flip_uplo(uplo), diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
431+
Y
432+
end
433+
end
434+
end
435+
436+
function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSC)
437+
queue = global_queue(context(A.nzVal), device(A.nzVal))
438+
onemklXsparse_optimize_trsm(sycl_queue(queue), filp_uplo(uplo), trans, diag, A.handle)
439+
return A
440+
end
441+
442+
function sparse_optimize_trsm!(uplo::Char, trans::Char, diag::Char, nrhs::Int, A::oneSparseMatrixCSR)
443+
queue = global_queue(context(A.nzVal), device(A.nzVal))
444+
onemklXsparse_optimize_trsm_advanced(sycl_queue(queue), 'C', flip_uplo(uplo), trans, diag, A.handle, nrhs)
445+
return A
446+
end

0 commit comments

Comments
 (0)